diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_acc/__init__.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_acc/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..aa17e5cb2190bbe5d4f9d349a03ff2ffb319e603 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_acc/__init__.pyi @@ -0,0 +1,15 @@ +from torch import Tensor +from torch.types import _dtype, _int, Device + +# Defined in torch/csrc/acc/Module.cpp +class PrivateUse1Hooks: + def has_primary_context(self, device_index: _int) -> bool: ... + def is_built(self) -> bool: ... + def is_avaible(self) -> bool: ... + +class DeviceGuard: + def type_(self) -> Device: ... + +def register_python_privateuseone_device_guard(guard: DeviceGuard) -> bool: ... +def register_python_privateuseone_hook(hook: PrivateUse1Hooks) -> bool: ... +def create_empty_tensor(shape: tuple[_int, ...], dtype: _dtype) -> Tensor: ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..67d515697cbe4b43edb18dbdc4cf0270ebf13fb2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi @@ -0,0 +1,4 @@ +from . import compiled_autograd, eval_frame, guards # noqa: F401 + +def strip_function_call(name: str) -> str: ... +def is_valid_var_name(name: str) -> bool | int: ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ef24582b5023109733955cc77db0a84fae03b3fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi @@ -0,0 +1,13 @@ +from collections.abc import Callable + +from torch import Tensor +from torch._dynamo.compiled_autograd import AutogradCompilerInstance + +def set_autograd_compiler( + autograd_compiler: Callable[[], AutogradCompilerInstance] | None, + dynamic: bool, +) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ... +def clear_cache() -> None: ... +def is_cache_empty() -> bool: ... +def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ... +def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi new file mode 100644 index 0000000000000000000000000000000000000000..641aaece6269c51fd94edc0ed0ceb2ac51a8b62c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi @@ -0,0 +1,84 @@ +import enum +import types +from collections.abc import Callable +from typing import Optional, overload + +from torch._dynamo.guards import GuardManagerWrapper +from torch._dynamo.types import DynamoCallback, DynamoGuardCompleteHook, DynamoGuardHook +from torch._guards import CompileId + +def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... +def set_skip_guard_eval_unsafe(value: bool) -> bool: ... +def get_eval_frame_callback() -> DynamoCallback: ... +def reset_code(code: types.CodeType) -> None: ... +def unsupported(obj1: object, obj2: object) -> object: ... +def set_code_exec_strategy( + code: types.CodeType, strategy: _FrameExecStrategy +) -> None: ... +def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... +def set_guard_complete_hook( + hook: Optional[DynamoGuardCompleteHook], +) -> Optional[DynamoGuardCompleteHook]: ... +def raise_sigtrap() -> None: ... +def set_c_recursion_limit(limit: int) -> None: ... +def get_c_recursion_limit() -> int: ... + +class _CacheEntry: + def check_fn(self, *args: object, **kwargs: object) -> bool: ... + def update_diff_guard_root_manager(self) -> None: ... + code: types.CodeType + compile_id: CompileId + # If we run into circular issues, just use object + guard_manager: GuardManagerWrapper + backend: Callable + next: _CacheEntry | None + +class _PrecompileEntry: + guard_manager: GuardManagerWrapper + +class _ExtraState: + def invalidate( + self, cache_entry: _CacheEntry, guard_manager: GuardManagerWrapper + ) -> None: ... + +class _FrameAction(enum.IntEnum): + DEFAULT = 0 + SKIP = 1 + RUN_ONLY = 2 + +class _FrameExecStrategy: + cur_action: _FrameAction + recursive_action: _FrameAction + + @overload + def __init__(self) -> None: ... + @overload + def __init__( + self, cur_action: _FrameAction, recursive_action: _FrameAction + ) -> None: ... + +# This is an object that encapsulates the Python FrameType, and exposes +# properties Dynamo cares about for a frame. +class _PyInterpreterFrame: + f_code: types.CodeType + f_locals: dict[str, object] + f_globals: dict[str, object] + f_builtins: dict[str, object] + f_lasti: int + f_lineno: int + f_back: types.FrameType + # A tuple containing cell objects captured by this frame. + closure: tuple[types.CellType] + +def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... + +py_opcode_caches: list[int] + +def code_framelocals_names(code: types.CodeType) -> tuple[str, ...]: ... +def _load_precompile_entry( + code: types.CodeType, + guard_manager: GuardManagerWrapper, + dynamo_code: types.CodeType, +) -> None: ... +def _reset_precompile_entries(code: types.CodeType) -> None: ... +def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e3003f0e97b12b58f65454ccaeb82d305f884233 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi @@ -0,0 +1,452 @@ +import enum +from collections.abc import Callable +from typing import Any, Optional, TypeAlias + +import torch + +# TODO: We should move the `GuardManagerType` +# defined in `guards.py` here and update other +# imports +GuardManagerType: TypeAlias = enum.Enum + +class GlobalStateGuard: + def check(self) -> bool: ... + def reason(self) -> str: ... + +class LeafGuard: + def verbose_code_parts(self) -> list[str]: ... + +class RelationalGuard: ... + +class GuardDebugInfo: + verbose_code_parts: list[str] + result: bool + num_guards_executed: int + +class GuardManager: + def check(self, value: Any) -> bool: ... + def check_verbose(self, value: Any) -> GuardDebugInfo: ... + + # Accessors + def globals_dict_manager( + self, + f_globals: dict[str, Any], + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def framelocals_manager( + self, + key: tuple[str, int], + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def dict_getitem_manager( + self, + key: Any, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def grad_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def generic_getattr_manager( + self, + attr: str, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def getitem_manager( + self, + key: Any, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def get_generic_dict_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def list_getitem_manager( + self, + key: Any, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def tuple_getitem_manager( + self, + key: Any, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def set_getitem_manager( + self, + index: Any, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def func_defaults_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def func_kwdefaults_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def tuple_iterator_getitem_manager( + self, + index: Any, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def weakref_call_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def call_function_no_args_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def global_weakref_manager( + self, + global_name: str, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def type_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def getattr_manager( + self, + attr: str, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def tensor_property_size_manager( + self, + idx: int, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def tensor_property_shape_manager( + self, + idx: int, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def tensor_property_storage_offset_manager( + self, + idx: int, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def indexed_manager( + self, + idx: int, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def lambda_manager( + self, + python_lambda: Callable[..., Any], + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def get_root(self) -> RootGuardManager: ... + def get_source(self) -> str: ... + def fail_count(self) -> int: ... + def get_child_managers(self) -> list[GuardManager]: ... + def repr(self) -> str: ... + def type_of_guarded_value(self) -> str: ... + def get_leaf_guards(self) -> list[LeafGuard]: ... + def get_accessors(self) -> list[GuardManager]: ... + def is_guarded_value_immutable(self) -> bool: ... + def is_tag_safe(self) -> bool: ... + def is_tag_safe_root(self) -> bool: ... + def has_no_accessors(self) -> bool: ... + def has_object_aliasing_guard(self) -> bool: ... + def get_type_of_guarded_value(self) -> type: ... + def type_dict_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def type_mro_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def code_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def closure_manager( + self, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + # Leaf guards + def add_lambda_guard( + self, user_lambda: Callable[..., Any], verbose_code_parts: list[str] + ) -> None: ... + def add_id_match_guard( + self, id_val: int, verbose_code_parts: list[str] + ) -> None: ... + def add_equals_match_guard( + self, + equals_val: Any, + verbose_code_parts: list[str], + ) -> None: ... + def add_global_state_guard( + self, initial_state: Any, verbose_code_parts: list[str] + ) -> None: ... + def add_torch_function_mode_stack_guard( + self, initial_stack: list[Any], verbose_code_parts: list[str] + ) -> None: ... + def add_mapping_keys_guard( + self, value: Any, verbose_code_parts: list[str] + ) -> None: ... + def add_dict_length_check_guard( + self, value: int, verbose_code_parts: list[str] + ) -> None: ... + def add_length_check_guard( + self, value: int, verbose_code_parts: list[str] + ) -> None: ... + def add_true_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_false_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_none_match_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_not_none_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_dispatch_key_set_guard( + self, + dispatch_key: Any, + verbose_code_parts: list[str], + ) -> None: ... + def add_tensor_match_guard( + self, + value: Any, + sizes: list[int], + strides: list[int], + tensor_name: str, + verbose_code_parts: list[str], + ptype: Any, + dispatch_keys: Any, + ) -> None: ... + def add_dynamic_indices_guard( + self, + value: set[Any], + verbose_code_parts: list[str], + ) -> None: ... + def add_no_hasattr_guard( + self, + attr_name: str, + verbose_code_parts: list[str], + ) -> None: ... + def add_dict_contains_guard( + self, + contains: bool, + key: Any, + verbose_code_parts: list[str], + ) -> None: ... + def add_type_match_guard( + self, + value: int, + verbose_code_parts: list[str], + ) -> None: ... + def add_dict_version_guard( + self, + value: Any, + verbose_code_parts: list[str], + ) -> None: ... + def add_set_contains_guard( + self, + contains: bool, + item: Any, + verbose_code_parts: list[str], + ) -> None: ... + def add_dual_level_match_guard( + self, + level: int, + verbose_code_parts: list[str], + ) -> None: ... + def add_float_is_nan_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_complex_is_nan_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def add_tuple_iterator_length_guard( + self, + length: int, + type_id: int, + verbose_code_parts: list[str], + ) -> None: ... + def add_range_iterator_match_guard( + self, + start: int, + stop: int, + step: int, + type_id: int, + verbose_code_parts: list[str], + ) -> None: ... + def add_default_device_guard( + self, + verbose_code_parts: list[str], + ) -> None: ... + def mark_tag_safe(self) -> None: ... + def mark_tag_safe_root(self) -> None: ... + +class RootGuardManager(GuardManager): + def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ... + def add_epilogue_lambda_guard( + self, + guard: LeafGuard, + verbose_code_parts: list[str], + ) -> None: ... + def clone_manager( + self, clone_filter_fn: Callable[[GuardManager], bool] + ) -> RootGuardManager: ... + def attach_compile_id(self, compile_id: str) -> None: ... + +class DictGuardManager(GuardManager): + def get_key_manager( + self, + index: int, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def get_value_manager( + self, + index: int, + source: str, + example_value: Any, + guard_manager_enum: GuardManagerType, + ) -> GuardManager: ... + def get_key_value_managers( + self, + ) -> dict[int, tuple[GuardManager, GuardManager]]: ... + +# Guard accessor stubs +class GuardAccessor: ... +class DictGetItemGuardAccessor(GuardAccessor): ... +class GetGenericDictGuardAccessor(GuardAccessor): ... +class TypeDictGuardAccessor(GuardAccessor): ... +class TypeMROGuardAccessor(GuardAccessor): ... +class ClosureGuardAccessor(GuardAccessor): ... +class TupleGetItemGuardAccessor(GuardAccessor): ... +class TypeGuardAccessor(GuardAccessor): ... +class CodeGuardAccessor(GuardAccessor): ... +class FuncDefaultsGuardAccessor(GuardAccessor): ... +class FuncKwDefaultsGuardAccessor(GuardAccessor): ... + +class GetAttrGuardAccessor(GuardAccessor): + def get_attr_name(self) -> str: ... + +def install_object_aliasing_guard( + x: GuardManager, + y: GuardManager, + verbose_code_parts: list[str], +) -> None: ... +def install_no_tensor_aliasing_guard( + guard_managers: list[GuardManager], + tensor_names: list[str], + verbose_code_parts: list[str], +) -> None: ... +def install_storage_overlapping_guard( + overlapping_guard_managers: list[GuardManager], + non_overlapping_guard_managers: list[GuardManager], + verbose_code_parts: list[str], +) -> None: ... +def install_symbolic_shape_guard( + guard_managers: list[GuardManager], + nargs_int: int, + nargs_float: int, + py_addr: int, + py_addr_keep_alive: Any, + verbose_code_parts: list[str], +) -> None: ... +def profile_guard_manager( + guard_manager: GuardManager, + f_locals: dict[str, Any], + n_iters: int, +) -> float: ... + +class TensorGuards: + def __init__( + self, + *, + dynamic_dims_sizes: list[torch.SymInt | None] | None = None, + dynamic_dims_strides: list[torch.SymInt | None] | None = None, + ) -> None: ... + def check(self, *args: Any) -> bool: ... + def check_verbose( + self, *args: Any, tensor_check_names: Optional[list[str]] = None + ) -> bool | str: ... + +def assert_size_stride( + item: torch.Tensor, + size: torch.types._size, + stride: torch.types._size, + op_name: str | None = None, +) -> None: ... +def assert_alignment( + item: torch.Tensor, + alignment: int, + op_name: str | None = None, +) -> None: ... +def check_obj_id(obj: object, expected: int) -> bool: ... +def check_type_id(obj: object, expected: int) -> bool: ... +def dict_version(d: dict[Any, Any]) -> int: ... +def compute_overlapping_tensors( + tensors: list[torch.Tensor], symbolic: bool = True +) -> set[int]: ... +def set_is_in_mode_without_ignore_compile_internals(value: bool) -> None: ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..039f9c22eea620bc9675d233684df72c7ac4471c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi @@ -0,0 +1,9 @@ +# Defined in torch/csrc/export/pybind.cpp +class CppExportedProgram: ... + +def deserialize_exported_program( + serialized_program: str, +) -> CppExportedProgram: ... +def serialize_exported_program( + cpp_exported_program: CppExportedProgram, +) -> str: ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f7a92ddd0c961513d42949e8c2c4b18fcadcc8cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi @@ -0,0 +1,25 @@ +# Defined in torch/csrc/export/pt2_archive_constants.h + +ARCHIVE_ROOT_NAME: str = ... +ARCHIVE_FORMAT_PATH: str = ... +ARCHIVE_FORMAT_VALUE: str = ... +ARCHIVE_VERSION_PATH: str = ... +ARCHIVE_VERSION_VALUE: str = ... +MODELS_DIR: str = ... +MODELS_FILENAME_FORMAT: str = ... +AOTINDUCTOR_DIR: str = ... +MTIA_DIR: str = ... +WEIGHTS_DIR: str = ... +WEIGHTS_CONFIG_FILENAME_FORMAT: str = ... +WEIGHT_FILENAME_PREFIX: str = ... +CONSTANTS_DIR: str = ... +CONSTANTS_CONFIG_FILENAME_FORMAT: str = ... +TENSOR_CONSTANT_FILENAME_PREFIX: str = ... +CUSTOM_OBJ_FILENAME_PREFIX: str = ... +SAMPLE_INPUTS_DIR: str = ... +SAMPLE_INPUTS_FILENAME_FORMAT: str = ... +EXECUTORCH_DIR: str = ... +EXTRA_DIR: str = ... +MODULE_INFO_PATH: str = ... +XL_MODEL_WEIGHTS_DIR: str = ... +XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ... diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..547ae8a6aaf9592ea0137898deeee23d619f425b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..564718264900dc8e429d0f77def99fec1a390bce Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f473f12e1afcb906c1661be8a5ae97601b2215 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96f0cb92bf014a2bcfbafb08f3e895a7262f4e47 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ae2741e295e3f3e4290269c03753a27902bb4ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0e30e3b3d0425b02d4f63e2cdfa3a125bbbe6bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c0f81a884f0c7257abd0977837d8c2a679bb926 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7439c22d66882d058e617edb85bc4407cfd742a9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__init__.py @@ -0,0 +1,35 @@ +# We are exposing all subpackages to the end-user. +# Because of possible inter-dependency, we want to avoid +# the cyclic imports, thus implementing lazy version +# as per https://peps.python.org/pep-0562/ + +from typing import TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from types import ModuleType + + from torch.ao.nn import ( # noqa: TC004 + intrinsic as intrinsic, + qat as qat, + quantizable as quantizable, + quantized as quantized, + sparse as sparse, + ) + + +__all__ = [ + "intrinsic", + "qat", + "quantizable", + "quantized", + "sparse", +] + + +def __getattr__(name: str) -> "ModuleType": + if name in __all__: + import importlib + + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..044c86ab2f0d3e129276caefab5dbb01308de36f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80ba84a84251db6229c38b5f2c48b233fe594fbb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py @@ -0,0 +1,41 @@ +import types + +from .modules import * # noqa: F403 +from .modules.fused import _FusedModule # noqa: F403 + + +# # Subpackages +# from . import qat # noqa: F403 +# from . import quantized # noqa: F403 + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] + + +# We are exposing all subpackages to the end-user. +# Because of possible inter-dependency, we want to avoid +# the cyclic imports, thus implementing lazy version +# as per https://peps.python.org/pep-0562/ +def __getattr__(name: str) -> types.ModuleType: + if name in __all__: + import importlib + + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ded23d8f7aaa34495cb97cc9be52e8be8a86774 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..132137b7357378fe29ef9a63310a554725aea86a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py @@ -0,0 +1,41 @@ +from .fused import ( # noqa: F401 + _FusedModule, + BNReLU2d, + BNReLU3d, + ConvAdd2d, + ConvAddReLU2d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearLeakyReLU, + LinearReLU, + LinearTanh, +) + + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7a0e63f2ad81b679421b883aad90f5158a94709 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbec59e4c78e7facd38fa796177559ab8a093f20 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..d189e3d92447da930ba487034b58c623e2e7a4ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py @@ -0,0 +1,289 @@ +# mypy: allow-untyped-defs +import torch +from torch.nn import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + Conv1d, + Conv2d, + Conv3d, + Linear, + ReLU, +) +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "ConvBn1d", + "ConvBn2d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] + + +# Used for identifying intrinsic modules used in quantization +class _FusedModule(torch.nn.Sequential): + pass + + +class ConvReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv1d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class ConvReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class ConvReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class LinearReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, relu): + assert ( + type_before_parametrizations(linear) == Linear + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(linear, relu) + + +class ConvBn1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(bn) == BatchNorm1d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBn2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(bn) == BatchNorm2d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBnReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(bn) == BatchNorm1d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class ConvBnReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(bn) == BatchNorm2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class ConvBn3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(bn) == BatchNorm3d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBnReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(bn) == BatchNorm3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class BNReLU2d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, batch_norm, relu): + assert ( + type_before_parametrizations(batch_norm) == BatchNorm2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(batch_norm, relu) + + +class BNReLU3d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, batch_norm, relu): + assert ( + type_before_parametrizations(batch_norm) == BatchNorm3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(batch_norm, relu) + + +class LinearBn1d(_FusedModule): + r"""This is a sequential container which calls the Linear and BatchNorm1d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, bn): + assert ( + type_before_parametrizations(linear) == Linear + and type_before_parametrizations(bn) == BatchNorm1d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(linear, bn) + + +class LinearLeakyReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and LeakyReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, leaky_relu): + assert type(linear) is Linear and type(leaky_relu) is torch.nn.LeakyReLU, ( + f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}" + ) + super().__init__(linear, leaky_relu) + + +class LinearTanh(_FusedModule): + r"""This is a sequential container which calls the Linear and Tanh modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, tanh): + assert type(linear) is Linear and type(tanh) is torch.nn.Tanh, ( + f"Incorrect types for input modules{type(linear)}{type(tanh)}" + ) + super().__init__(linear, tanh) + + +class ConvAdd2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d modules with extra Add. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, add): + super().__init__(conv) + self.add = add + + def forward(self, x1, x2): # type: ignore[override] + r"""Applies convolution to x1 and adds the result to x2.""" + return self.add(self[0](x1), x2) + + +class ConvAddReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d, add, Relu. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, add, relu): + super().__init__(conv) + self.add = add + self.relu = relu + + def forward(self, x1, x2): # type: ignore[override] + r"""Applies convolution to x1, adds the result to x2, and applies ReLU.""" + return self.relu(self.add(self[0](x1), x2)) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ba84285a8ab32370d1c92194f5d657418d4fa41 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18534bbc588e7480ac6529c6648c5976eadaea3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py @@ -0,0 +1,32 @@ +from .conv_fused import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + freeze_bn_stats, + update_bn_stats, +) +from .linear_fused import LinearBn1d +from .linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", + "LinearBn1d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c75c7b839f37eeeb44ed4f8750a2f7bec427b78 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e2a5d9e0e578da985c0f17963664c1a951d56e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7325c4f03c9afcb24850c480e94f3eaf0297cef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f31b618d6ff858b09ae9ccfd882aa0efb8fa1e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..10f67764d8f05143e4bcc15ad1196f801015370a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -0,0 +1,958 @@ +# mypy: allow-untyped-defs +import math +from typing import ClassVar + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.qat as nnqat +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.modules.utils import _pair, _single, _triple +from torch.nn.parameter import Parameter +from torch.nn.utils import fuse_conv_bn_weights + + +__all__ = [ + "ConvBn1d", + "ConvBnReLU1d", + "ConvReLU1d", + "ConvBn2d", + "ConvBnReLU2d", + "ConvReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "ConvReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] +_BN_CLASS_MAP = { + 1: nn.BatchNorm1d, + 2: nn.BatchNorm2d, + 3: nn.BatchNorm3d, +} + + +class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): + _version = 2 + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + # BatchNormNd args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + dim=2, + ): + nn.modules.conv._ConvNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + False, + padding_mode, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.freeze_bn = freeze_bn if self.training else True + self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True) + self.weight_fake_quant = self.qconfig.weight() + if bias: + self.bias = Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + self.reset_bn_parameters() + + # this needs to be called after reset_bn_parameters, + # as they modify the same state + if self.training: + if freeze_bn: + self.freeze_bn_stats() + else: + self.update_bn_stats() + else: + self.freeze_bn_stats() + + self._enable_slow_path_for_better_numerical_stability = False + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def reset_bn_parameters(self): + self.bn.reset_running_stats() + init.uniform_(self.bn.weight) + init.zeros_(self.bn.bias) + # note: below is actually for conv, not BN + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def _forward(self, input): + if self._enable_slow_path_for_better_numerical_stability: + return self._forward_slow(input) + return self._forward_approximate(input) + + def _forward_approximate(self, input): + """Approximated method to fuse conv and bn. It requires only one forward pass. + conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std + """ + assert self.bn.running_var is not None + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # using zero bias here since the bias for original conv + # will be added later + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias, dtype=input.dtype) + else: + zero_bias = torch.zeros( + self.out_channels, device=scaled_weight.device, dtype=input.dtype + ) + conv = self._conv_forward(input, scaled_weight, zero_bias) + conv_orig = conv / scale_factor.reshape(bias_shape) + if self.bias is not None: + conv_orig = conv_orig + self.bias.reshape(bias_shape) + conv = self.bn(conv_orig) + return conv + + def _forward_slow(self, input): + """ + A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf + It requires two forward passes but handles the case bn.weight == 0 + + Conv: Y = WX + B_c + Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c + + Batch statistics: + mean_Y = Y.mean() + = Y0.mean() + B_c + var_Y = (Y - mean_Y)^2.mean() + = (Y0 - Y0.mean())^2.mean() + BN (r: bn.weight, beta: bn.bias): + Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta + = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta + + Fused Conv BN training (std_Y = sqrt(var_Y + eps)): + Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta + = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta + + Fused Conv BN inference (running_std = sqrt(running_var + eps)): + Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta + + QAT with fused conv bn: + Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta + = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta + Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta + """ + + assert self.bn.running_var is not None + assert self.bn.running_mean is not None + + # using zero bias here since the bias for original conv + # will be added later + zero_bias = torch.zeros( + self.out_channels, device=self.weight.device, dtype=input.dtype + ) + + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + + if self.bn.training: + # needed to compute batch mean/std + conv_out = self._conv_forward(input, self.weight, zero_bias) + # update bn statistics + with torch.no_grad(): + conv_out_bias = ( + conv_out + if self.bias is None + else conv_out + self.bias.reshape(bias_shape) + ) + self.bn(conv_out_bias) + + # fused conv + bn without bias using bn running statistics + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # fused conv without bias for inference: (r * W / running_std) * X + conv_bn = self._conv_forward(input, scaled_weight, zero_bias) + + avg_dims = [0] + list(range(2, len(self.weight.shape))) + batch_mean = conv_out.mean(avg_dims) + batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean( + avg_dims + ) + batch_std = torch.sqrt(batch_var + self.bn.eps) + + # scale to use batch std in training mode + # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y) + unscale_factor = running_std / batch_std + conv_bn *= unscale_factor.reshape(bias_shape) + + fused_mean = batch_mean + fused_std = batch_std + else: + # fused conv + bn without bias using bn running statistics + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # fused conv without bias for inference: (r * W / running_std) * X + conv_bn = self._conv_forward(input, scaled_weight, zero_bias) + + fused_mean = self.bn.running_mean - ( + self.bias if self.bias is not None else 0 + ) + fused_std = running_std + + # fused bias = beta - r * mean / std + fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std + conv_bn += fused_bias.reshape(bias_shape) + + # HACK to let conv bias participate in loss to avoid DDP error (parameters + # were not used in producing loss) + if self.bias is not None: + conv_bn += (self.bias - self.bias).reshape(bias_shape) + + return conv_bn + + def forward(self, input): + return self._forward(input) + + def train(self, mode=True): + """ + Batchnorm's training behavior is using the self.training flag. Prevent + changing it if BN is frozen. This makes sure that calling `model.train()` + on a model with a frozen BN will behave properly. + """ + self.training = mode + if not self.freeze_bn: + for module in self.children(): + module.train(mode) + return self + + # ===== Serialization version history ===== + # + # Version 1/None + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- gamma : Tensor + # |--- beta : Tensor + # |--- running_mean : Tensor + # |--- running_var : Tensor + # |--- num_batches_tracked : Tensor + # + # Version 2 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- bn : Module + # |--- weight : Tensor (moved from v1.self.gamma) + # |--- bias : Tensor (moved from v1.self.beta) + # |--- running_mean : Tensor (moved from v1.self.running_mean) + # |--- running_var : Tensor (moved from v1.self.running_var) + # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version == 1: + # BN related parameters and buffers were moved into the BN module for v2 + v2_to_v1_names = { + "bn.weight": "gamma", + "bn.bias": "beta", + "bn.running_mean": "running_mean", + "bn.running_var": "running_var", + "bn.num_batches_tracked": "num_batches_tracked", + } + for v2_name, v1_name in v2_to_v1_names.items(): + if prefix + v1_name in state_dict: + state_dict[prefix + v2_name] = state_dict[prefix + v1_name] + state_dict.pop(prefix + v1_name) + elif prefix + v2_name in state_dict: + # there was a brief period where forward compatibility + # for this module was broken (between + # https://github.com/pytorch/pytorch/pull/38478 + # and https://github.com/pytorch/pytorch/pull/38820) + # and modules emitted the v2 state_dict format while + # specifying that version == 1. This patches the forward + # compatibility issue by allowing the v2 style entries to + # be used. + pass + elif strict: + missing_keys.append(prefix + v2_name) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound + # has no __name__ (code is fine though) + assert type(mod) is cls._FLOAT_MODULE, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + qconfig = mod.qconfig + conv, bn = mod[0], mod[1] # type: ignore[index] + qat_convbn = cls( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + bn.eps, + bn.momentum, + False, + qconfig, + ) + qat_convbn.weight = conv.weight + qat_convbn.bias = conv.bias + qat_convbn.bn.weight = bn.weight + qat_convbn.bn.bias = bn.bias + qat_convbn.bn.running_mean = bn.running_mean + qat_convbn.bn.running_var = bn.running_var + # mypy error: Cannot determine type of 'num_batches_tracked' + qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked + return qat_convbn + + def to_float(self): + cls = type(self) + conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined] + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.padding_mode, + ) + conv.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + conv.bias = torch.nn.Parameter(self.bias.detach()) + + if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined] + # fuse bn into conv + assert self.bn.running_var is not None and self.bn.running_mean is not None + conv.weight, conv.bias = fuse_conv_bn_weights( + conv.weight, + conv.bias, + self.bn.running_mean, + self.bn.running_var, + self.bn.eps, + self.bn.weight, + self.bn.bias, + ) + + if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined] + modules = [] + modules.append(conv) + relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined] + modules.append(relu) + conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined] + conv_relu.train(self.training) + return conv_relu + else: + conv.train(self.training) + return conv + + +class ConvBn1d(_ConvBnNd, nn.Conv1d): + r""" + A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d`. + + Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = None + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + + def __init__( + self, + # Conv1d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _single(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=1, + ) + + +class ConvBnReLU1d(ConvBn1d): + r""" + A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + # base class defines _FLOAT_MODULE as "ConvBn1d" + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBnReLU1d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[type[nn.Module] | None] = nni.ConvReLU1d + + def forward(self, input): + r"""Performs forward pass through fused Conv1d, BatchNorm1d, and ReLU.""" + return F.relu(self._forward(input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Creates a QAT module from a floating point module.""" + return super().from_float(mod, use_precomputed_fake_quant) + + +class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): + r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv1d` and + :class:`~torch.nn.BatchNorm1d`. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = None + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + # pyrefly: ignore [bad-argument-type] + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + r"""Performs forward pass through fused Conv1d and ReLU.""" + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a QAT module from a floating point module.""" + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvBn2d(_ConvBnNd, nn.Conv2d): + r""" + A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv2d` and + :class:`torch.nn.BatchNorm2d`. + + Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = None + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm2d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=2, + ) + + +class ConvBnReLU2d(ConvBn2d): + r""" + A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv2d` and + :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + # base class defines _FLOAT_MODULE as "ConvBn2d" + _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm2d]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU2d] | None] = nni.ConvReLU2d + + def forward(self, input): + r"""Performs forward pass through fused Conv2d, BatchNorm2d, and ReLU.""" + return F.relu(self._forward(input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Creates a QAT module from a floating point module.""" + return super().from_float(mod, use_precomputed_fake_quant) + + +class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): + r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv2d` and + :class:`~torch.nn.BatchNorm2d`. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = None + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + # pyrefly: ignore [bad-argument-type] + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + r"""Performs forward pass through fused Conv2d and ReLU.""" + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a QAT module from a floating point module.""" + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvBn3d(_ConvBnNd, nn.Conv3d): + r""" + A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv3d` and + :class:`torch.nn.BatchNorm3d`. + + Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = None + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm3d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=3, + ) + + +class ConvBnReLU3d(ConvBn3d): + r""" + A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv3d` and + :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm3d]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[type[nn.ReLU] | None] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d] | None] = nni.ConvReLU3d + + def forward(self, input): + r"""Performs forward pass through fused Conv3d, BatchNorm3d, and ReLU.""" + return F.relu(ConvBn3d._forward(self, input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Creates a QAT module from a floating point module.""" + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): + r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv3d` and + :class:`~torch.nn.BatchNorm3d`. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = None + _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + # pyrefly: ignore [bad-argument-type] + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + r"""Performs forward pass through fused Conv3d and ReLU.""" + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a QAT module from a floating point module.""" + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +def update_bn_stats(mod): + if type(mod) in { + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + }: + mod.update_bn_stats() + + +def freeze_bn_stats(mod): + if type(mod) in { + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + }: + mod.freeze_bn_stats() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..8458cef76ee3a37bce33d924d2d60d2ca971a614 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.parameter import Parameter +from torch.nn.utils.fusion import fuse_linear_bn_weights + + +__all__ = [ + "LinearBn1d", +] + + +class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule): + r""" + A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached + with FakeQuantize modules for weight, used in quantization aware training. + + We combined the interface of :class:`torch.nn.Linear` and + :class:torch.nn.BatchNorm1d`. + + Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + def __init__( + self, + # Linear args + in_features, + out_features, + bias=True, + # BatchNorm1d args + # num_features: out_features + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + nn.modules.linear.Linear.__init__(self, in_features, out_features, bias) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.freeze_bn = freeze_bn if self.training else True + self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True) + self.weight_fake_quant = self.qconfig.weight() + if bias: + self.bias = Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + self.reset_bn_parameters() + + # this needs to be called after reset_bn_parameters, + # as they modify the same state + if self.training: + if freeze_bn: + self.freeze_bn_stats() + else: + self.update_bn_stats() + else: + self.freeze_bn_stats() + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def reset_bn_parameters(self): + self.bn.reset_running_stats() + init.uniform_(self.bn.weight) + init.zeros_(self.bn.bias) + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def forward(self, input): + assert self.bn.running_var is not None + + # Scale the linear weights by BN's running statistics to reduce + # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18 + # for motivation. + # + # Instead of + # + # x1 = F.linear(x0, fq(w), b) + # x2 = self.bn(x1) + # + # We have + # + # # scale the weight by previous batch's running statistics + # scale_factor = bn.w / bn.running_std_from_prev_batch + # # do the linear transformation without bias + # x1_scaled = F.linear(x0, fq(w * scale_factor), 0) + # # reverse the scaling and add original bias + # x1_orig = x1_scaled / scale_factor + b + # x2 = self.bn(x1_orig) + + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias) + else: + zero_bias = torch.zeros(self.out_features, device=scaled_weight.device) + linear_out = F.linear(input, scaled_weight, zero_bias) + linear_out_orig = linear_out / scale_factor.reshape(bias_shape) + if self.bias is not None: + linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape) + bn_out = self.bn(linear_out_orig) + return bn_out + + def train(self, mode=True): + """ + Batchnorm's training behavior is using the self.training flag. Prevent + changing it if BN is frozen. This makes sure that calling `model.train()` + on a model with a frozen BN will behave properly. + """ + self.training = mode + if not self.freeze_bn: + for module in self.children(): + module.train(mode) + return self + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + + Args: + mod: A float module, either produced by torch.ao.quantization + utilities or directly from the user. + """ + assert type(mod) is nni.LinearBn1d, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + nni.LinearBn1d.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid config" + qconfig = mod.qconfig + linear, bn = mod[0], mod[1] + qat_linearbn = cls( + linear.in_features, + linear.out_features, + linear.bias is not None, + bn.eps, + bn.momentum, + False, + qconfig, + ) + qat_linearbn.weight = linear.weight # type: ignore[assignment] + qat_linearbn.bias = linear.bias # type: ignore[assignment] + qat_linearbn.bn.weight = bn.weight # type: ignore[assignment] + qat_linearbn.bn.bias = bn.bias # type: ignore[assignment] + qat_linearbn.bn.running_mean = bn.running_mean # type: ignore[assignment] + qat_linearbn.bn.running_var = bn.running_var # type: ignore[assignment] + qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[assignment] + return qat_linearbn + + def to_float(self): + linear = torch.nn.Linear(self.in_features, self.out_features) + assert self.bn.running_var is not None and self.bn.running_mean is not None + linear.weight, linear.bias = fuse_linear_bn_weights( + self.weight, + self.bias, + self.bn.running_mean, + self.bn.running_var, + self.bn.eps, + self.bn.weight, + self.bn.bias, + ) + return linear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..183286ebb8dad25e49cd2fcd7c2dba2436003823 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.qat as nnqat +import torch.nn.functional as F +from torch.ao.nn.intrinsic.modules.fused import _FusedModule + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfigAny + + +__all__ = ["LinearReLU"] + + +class LinearReLU(nnqat.Linear, _FusedModule): + r""" + A LinearReLU module fused from Linear and ReLU modules, attached with + FakeQuantize modules for weight, used in + quantization aware training. + + We adopt the same interface as :class:`torch.nn.Linear`. + + Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.qat.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + # pyrefly: ignore [bad-override] + _FLOAT_MODULE = nni.LinearReLU + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + qconfig: QConfigAny = None, + ) -> None: + super().__init__(in_features, out_features, bias, qconfig) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias)) + + @classmethod + def from_float( + cls, + mod: torch.nn.Module, + use_precomputed_fake_quant: bool = False, + ) -> LinearReLU: + return super().from_float(mod, use_precomputed_fake_quant) # type: ignore[no-untyped-call,no-any-return] + + def to_float(self) -> nni.LinearReLU: + linear = torch.nn.Linear( + self.in_features, self.out_features, self.bias is not None + ) + linear.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + linear.bias = torch.nn.Parameter(self.bias.detach()) + relu = torch.nn.ReLU() + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) # type: ignore[no-untyped-call] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6af3b4aeee893966323cc4e73a27ff41814fc251 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py @@ -0,0 +1,15 @@ +from .modules import * # noqa: F403 + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4266a3e8ad55402de9353db37403a385fcb9558c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c0fcdd2a34fec1045ebc398e09a00a587899b51 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a6c3c57c7828861b574e76b134aee2c23f0aad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py @@ -0,0 +1,6 @@ +from .linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ecbd3a3f315e94abe6d049dada994b7f87e67d2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde10368f5ad43d018ff440964a3f0077f633d33 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..620d24ae43e466ecd7883acf7df627641ebfdb24 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -0,0 +1,72 @@ +from typing import Any +from typing_extensions import Self + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized.dynamic as nnqd + + +__all__ = ["LinearReLU"] + + +class LinearReLU(nnqd.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules that can be used + for dynamic quantization. + Supports both, FP16 and INT8 quantization. + + We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.dynamic.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + # pyrefly: ignore [bad-override] + _FLOAT_MODULE = nni.LinearReLU + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = torch.qint8, + ) -> None: + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._packed_params.dtype == torch.qint8: + # TODO check if we should set reduce_rage = True by default here + Y = torch.ops.quantized.linear_relu_dynamic( + x, self._packed_params._packed_params, reduce_range=True + ) + elif self._packed_params.dtype == torch.float16: + Y = torch.ops.quantized.linear_relu_dynamic_fp16( + x, self._packed_params._packed_params + ) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!") + return Y.to(x.dtype) + + def _get_name(self) -> str: + return "DynamicQuantizedLinearReLU" + + @classmethod + def from_float( + cls, mod: torch.nn.Module, use_precomputed_fake_quant: bool = False + ) -> Self: + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qlinear_relu: Any) -> Self: # type: ignore[override] + return super().from_reference(ref_qlinear_relu[0]) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fa4dcec2597e18c002489405894ea7251d5156 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py @@ -0,0 +1,18 @@ +from .bn_relu import BNReLU2d, BNReLU3d +from .conv_add import ConvAdd2d, ConvAddReLU2d +from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d +from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh + + +__all__ = [ + "LinearReLU", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "BNReLU2d", + "BNReLU3d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c5dfff5971cb88f743e5d7d363f1432696cf3f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d1c4610809d81abcd9e86a61c244170ec868685 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62aacbe7a8515cbfb13e4e839ee2f4f1c81bc30f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f9cbd32b350e20054f762b77b2027518d06b98d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d75829513c29542a6f8df06a639465b15a2380ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..f05618c0949e1164f05cbd1edbfb8eb6440063e9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs + +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq + + +__all__ = ["BNReLU2d", "BNReLU3d"] + + +class BNReLU2d(nnq.BatchNorm2d): + r""" + A BNReLU2d module is a fused module of BatchNorm2d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`. + + Attributes: + Same as torch.ao.nn.quantized.BatchNorm2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + super().__init__( + num_features, eps=eps, momentum=momentum, device=device, dtype=dtype + ) + + def forward(self, input): + r"""Applies fused BatchNorm2d and ReLU.""" + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return torch.ops.quantized.batch_norm2d_relu( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedBNReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module.""" + # TODO: Add qat support for BNReLU2d + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, bn_relu, output_scale, output_zero_point): + r"""Creates a quantized module from a reference module.""" + return super().from_reference(bn_relu[0], output_scale, output_zero_point) + + +class BNReLU3d(nnq.BatchNorm3d): + r""" + A BNReLU3d module is a fused module of BatchNorm3d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`. + + Attributes: + Same as torch.ao.nn.quantized.BatchNorm3d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + super().__init__( + num_features, eps=eps, momentum=momentum, device=device, dtype=dtype + ) + + def forward(self, input): + r"""Applies fused BatchNorm3d and ReLU.""" + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + return torch.ops.quantized.batch_norm3d_relu( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedBNReLU3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module.""" + # TODO: Add qat support for BNReLU3d + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, bn_relu, output_scale, output_zero_point): + r"""Creates a quantized module from a reference module.""" + return super().from_reference(bn_relu[0], output_scale, output_zero_point) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py new file mode 100644 index 0000000000000000000000000000000000000000..82d5673e7173c56b5b56d2bd48a0b154bbfdfe9e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -0,0 +1,153 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq +import torch.nn.functional as F + + +_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding + + +class ConvAdd2d(nnq.Conv2d): + r""" + A ConvAdd2d module is a fused module of Conv2d and Add + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input, extra_input): # type: ignore[override] + r"""Applies fused quantized Conv2d and addition.""" + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_add( + input, extra_input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvAdd2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module.""" + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Creates a quantized module from a reference module.""" + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvAddReLU2d(nnq.Conv2d): + r""" + A ConvAddReLU2d module is a fused module of Conv2d, Add and Relu + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input, extra_input): # type: ignore[override] + r"""Applies fused quantized Conv2d, addition, and ReLU.""" + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_add_relu( + input, extra_input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvAddReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module.""" + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Creates a quantized module from a reference module.""" + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..c31df28905cd7c9c17147c965f5bd2199af2920a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -0,0 +1,276 @@ +# mypy: allow-untyped-defs + +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq +import torch.nn.functional as F +from torch.nn.utils import fuse_conv_bn_weights + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", +] + +_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding + + +# TODO: factor out the common parts to ConvNd +class ConvReLU1d(nnq.Conv1d): + r""" + A ConvReLU1d module is a fused module of Conv1d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv1d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + # pyrefly: ignore [bad-argument-type] + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + r"""Applies fused quantized Conv1d and ReLU.""" + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv1d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU1d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module.""" + if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU1d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float(mod, use_precomputed_fake_quant) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Creates a quantized module from a reference module.""" + assert type(ref_qconv) is not torch.ao.nn.intrinsic.ConvBnReLU1d, ( + "BatchNorm1d should be fused into Conv1d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvReLU2d(nnq.Conv2d): + r""" + A ConvReLU2d module is a fused module of Conv2d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + r"""Applies fused quantized Conv2d and ReLU.""" + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module.""" + if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU2d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Creates a quantized module from a reference module.""" + assert type(ref_qconv) is not torch.ao.nn.intrinsic.ConvBnReLU2d, ( + "BatchNorm2d should be fused into Conv2d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvReLU3d(nnq.Conv3d): + r""" + A ConvReLU3d module is a fused module of Conv3d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`. + + Attributes: Same as torch.ao.nn.quantized.Conv3d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + r"""Applies fused quantized Conv3d and ReLU.""" + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv3d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module.""" + if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU3d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Creates a quantized module from a reference module.""" + assert type(ref_qconv) is not torch.ao.nn.intrinsic.ConvBnReLU3d, ( + "BatchNorm3d should be fused into Conv3d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec84101ee0da62e3923362f444368b2a429d8b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -0,0 +1,190 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized as nnq +from torch.ao.nn.quantized.modules.utils import _quantize_weight + + +__all__ = [ + "LinearReLU", + "LinearLeakyReLU", + "LinearTanh", +] + + +class LinearReLU(nnq.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_relu( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLinearReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + @classmethod + def from_reference(cls, ref_linear_relu, output_scale, output_zero_point): + return super().from_reference( + ref_linear_relu[0], output_scale, output_zero_point + ) + + +class LinearLeakyReLU(nnq.Linear): + r""" + For onednn backend only + A LinearLeakyReLU module fused from Linear and LeakyReLU modules + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + Attributes: + Same as torch.ao.nn.quantized.Linear + + negative_slope + Examples:: + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment] + + def __init__( + self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8 + ): + super().__init__(in_features, out_features, bias, dtype) + self.negative_slope = negative_slope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_leaky_relu( + x, + self._packed_params._packed_params, + self.scale, + self.zero_point, + self.negative_slope, + ) + + def _get_name(self): + return "QuantizedLinearLeakyReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) is nni.LinearLeakyReLU, ( + "Input float module should be LinearLeakyReLU" + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + leaky_relu = mod[1] + mod = mod[0] + weight_post_process = mod.qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_leaky_relu = cls( + mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype + ) + qlinear_leaky_relu.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type] + qlinear_leaky_relu.scale = float(act_scale) + qlinear_leaky_relu.zero_point = int(act_zp) + return qlinear_leaky_relu + + @classmethod + def from_reference(cls, ref_mod, output_scale, output_zero_point): + linear = ref_mod[0] + leaky_relu = ref_mod[1] + qlinear_leaky_relu = cls( + linear.in_features, linear.out_features, leaky_relu.negative_slope + ) + qweight = linear.get_quantized_weight() + qlinear_leaky_relu.set_weight_bias(qweight, linear.bias) + qlinear_leaky_relu.scale = float(output_scale) + qlinear_leaky_relu.zero_point = int(output_zero_point) + return qlinear_leaky_relu + + +class LinearTanh(nnq.Linear): + r""" + A LinearTanh module fused from Linear and Tanh modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearTanh(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_tanh( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLinearTanh" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) is nni.LinearTanh, "Input float module should be LinearTanh" + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + mod = mod[0] + weight_post_process = mod.qconfig.weight() # type: ignore[union-attr,operator] + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear_tanh.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type] + qlinear_tanh.scale = float(act_scale) + qlinear_tanh.zero_point = int(act_zp) + return qlinear_tanh + + @classmethod + def from_reference(cls, ref_mod, output_scale, output_zero_point): + linear = ref_mod[0] + qlinear_tanh = cls(linear.in_features, linear.out_features) + qweight = linear.get_quantized_weight() + qlinear_tanh.set_weight_bias(qweight, linear.bias) + qlinear_tanh.scale = float(output_scale) + qlinear_tanh.zero_point = int(output_zero_point) + return qlinear_tanh diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..214848c14e3411b536e89a47dd7a328d91751695 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..221107660158171ada5d1823cc193666c9e152e7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from .activation import MultiheadAttention +from .rnn import LSTM, LSTMCell + + +__all__ = [ + "LSTM", + "LSTMCell", + "MultiheadAttention", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5491ac17fd58aad8a552edaaf0814bf07e051c52 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01e0c7c7de9d387d0098f6c48f2b496be3dde8d3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d426733acbd7dc81aa939dd0b45a38e7f13646d3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/activation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..d808d50c366c68b8aa0d61a50b9f6db2d72c9ff2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/activation.py @@ -0,0 +1,579 @@ +# mypy: allow-untyped-defs +import warnings + +import torch +import torch.jit # this is needed to avoid a circular import +import torch.nn.functional as F +from torch import nn, Tensor + + +__all__ = ["MultiheadAttention"] + + +class MultiheadAttention(nn.MultiheadAttention): + _FLOAT_MODULE = nn.MultiheadAttention + + r"""Quantizable implementation of the MultiheadAttention. + + Note:: + Please, refer to :class:`~torch.nn.MultiheadAttention` for more + information + + Allows the model to jointly attend to information from different + representation subspaces. + See reference: Attention Is All You Need + + The original MHA module is not quantizable. + This reimplements it by explicitly instantiating the linear layers. + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set + to :attr:`embed_dim` such that query, key, and value have the same + number of features. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + Note:: + Please, follow the quantization flow to convert the quantizable MHA. + """ + __constants__ = ["batch_first"] + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + add_bias_kv: bool = False, + add_zero_attn: bool = False, + kdim: int | None = None, + vdim: int | None = None, + batch_first: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + embed_dim, + num_heads, + dropout, + bias, + add_bias_kv, + add_zero_attn, + kdim, + vdim, + batch_first, + **factory_kwargs, + ) + self.linear_Q = nn.Linear( + self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs + ) + self.linear_K = nn.Linear( + self.kdim, self.embed_dim, bias=bias, **factory_kwargs + ) + self.linear_V = nn.Linear( + self.vdim, self.embed_dim, bias=bias, **factory_kwargs + ) + # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969 + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs + ) # type: ignore[assignment] + + # Functionals + self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional() + # note: importing torch.ao.nn.quantized at top creates a circular import + + # Quant/Dequant + self.quant_attn_output = torch.ao.quantization.QuantStub() + self.quant_attn_output_weights = torch.ao.quantization.QuantStub() + self.dequant_q = torch.ao.quantization.DeQuantStub() + self.dequant_k = torch.ao.quantization.DeQuantStub() + self.dequant_v = torch.ao.quantization.DeQuantStub() + + def _get_name(self): + return "QuantizableMultiheadAttention" + + @classmethod + def from_float(cls, other): + assert type(other) is cls._FLOAT_MODULE + assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" + # Setting the dropout to 0.0! + observed = cls( + other.embed_dim, + other.num_heads, + other.dropout, + (other.in_proj_bias is not None), + (other.bias_k is not None), + other.add_zero_attn, + other.kdim, + other.vdim, + other.batch_first, + ) + observed.bias_k = other.bias_k + observed.bias_v = other.bias_v + observed.qconfig = other.qconfig + + # Set the linear weights + # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969 + observed.out_proj.weight = other.out_proj.weight + observed.out_proj.bias = other.out_proj.bias + if other._qkv_same_embed_dim: + # Use separate params + bias = other.in_proj_bias + _start = 0 + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_Q.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_Q.bias = bias + + bias = other.in_proj_bias + _start = _end + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_K.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_K.bias = bias + + bias = other.in_proj_bias + _start = _end + weight = other.in_proj_weight[_start:, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:], bias.requires_grad) + observed.linear_V.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_V.bias = bias + else: + observed.linear_Q.weight = nn.Parameter(other.q_proj_weight) + observed.linear_K.weight = nn.Parameter(other.k_proj_weight) + observed.linear_V.weight = nn.Parameter(other.v_proj_weight) + if other.in_proj_bias is None: + # pyrefly: ignore [bad-assignment] + observed.linear_Q.bias = None + # pyrefly: ignore [bad-assignment] + observed.linear_K.bias = None + # pyrefly: ignore [bad-assignment] + observed.linear_V.bias = None + else: + observed.linear_Q.bias = nn.Parameter( + other.in_proj_bias[0 : other.embed_dim] + ) + observed.linear_K.bias = nn.Parameter( + other.in_proj_bias[other.embed_dim : (other.embed_dim * 2)] + ) + observed.linear_V.bias = nn.Parameter( + other.in_proj_bias[(other.embed_dim * 2) :] + ) + observed.eval() + # Explicit prepare + observed = torch.ao.quantization.prepare(observed, inplace=True) + return observed + + @torch.jit.unused + def dequantize(self): + r"""Utility to convert the quantized MHA back to float. + + The motivation for this is that it is not trivial to convert the weights + from the format that is used in the quantized version back to the + float. + """ + fp = self._FLOAT_MODULE( + self.embed_dim, + self.num_heads, + self.dropout, + (self.linear_Q._weight_bias()[1] is not None), # type: ignore[operator] + (self.bias_k is not None), + self.add_zero_attn, + self.kdim, + self.vdim, + self.batch_first, + ) + assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim + if self.bias_k is not None: + fp.bias_k = nn.Parameter(self.bias_k.dequantize()) + if self.bias_v is not None: + fp.bias_v = nn.Parameter(self.bias_v.dequantize()) + + # Set the linear weights + # Note: Because the linear layers are quantized, mypy does not know how + # to deal with them -- might need to ignore the typing checks. + # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type] + fp.out_proj.weight = nn.Parameter(w.dequantize()) + if b is not None: + fp.out_proj.bias = nn.Parameter(b) + + wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator] + wQ = wQ.dequantize() + wK, bK = self.linear_K._weight_bias() # type: ignore[operator] + wK = wK.dequantize() + wV, bV = self.linear_V._weight_bias() # type: ignore[operator] + wV = wV.dequantize() + if fp._qkv_same_embed_dim: + # Use separate params + _start = 0 + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wQ + if fp.in_proj_bias is not None: + # pyrefly: ignore [bad-argument-type] + assert all(bQ == 0) + fp.in_proj_bias[_start:_end] = bQ + + _start = _end + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wK + if fp.in_proj_bias is not None: + # pyrefly: ignore [bad-argument-type] + assert all(bK == 0) + fp.in_proj_bias[_start:_end] = bK + + _start = _end + fp.in_proj_weight[_start:, :] = wV + if fp.in_proj_bias is not None: + # pyrefly: ignore [bad-argument-type] + assert all(bV == 0) + fp.in_proj_bias[_start:] = bV + else: + fp.q_proj_weight = nn.Parameter(wQ) + fp.k_proj_weight = nn.Parameter(wK) + fp.v_proj_weight = nn.Parameter(wV) + if fp.in_proj_bias is None: + # pyrefly: ignore [bad-assignment] + self.linear_Q.bias = None + # pyrefly: ignore [bad-assignment] + self.linear_K.bias = None + # pyrefly: ignore [bad-assignment] + self.linear_V.bias = None + else: + fp.in_proj_bias[0 : fp.embed_dim] = bQ + fp.in_proj_bias[fp.embed_dim : (fp.embed_dim * 2)] = bK + fp.in_proj_bias[(fp.embed_dim * 2) :] = bV + + return fp + + @classmethod + def from_observed(cls, other): + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + # See nn.quantized.MultiheadAttention + raise NotImplementedError( + "It looks like you are trying to prepare an " + "MHA module. Please, see " + "the examples on quantizable MHAs." + ) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Tensor | None = None, + need_weights: bool = True, + attn_mask: Tensor | None = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Tensor | None]: + r""" + Note:: + Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more + information + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask. + Default: ``False``. + - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True.``. Default: True (i.e. average weights across heads) + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged + across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length, + S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(N, num_heads, L, S)`. + """ + return self._forward_impl( + query, + key, + value, + key_padding_mask, + need_weights, + attn_mask, + average_attn_weights, + is_causal, + ) + + def _forward_impl( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Tensor | None = None, + need_weights: bool = True, + attn_mask: Tensor | None = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Tensor | None]: + # This version will not deal with the static key/value pairs. + # Keeping it here for future changes. + # + # TODO: This method has some duplicate lines with the + # `torch.nn.functional.multi_head_attention`. Will need to refactor. + static_k = None + static_v = None + + if attn_mask is not None and is_causal: + raise AssertionError("Only allow causal mask or attn_mask") + + if is_causal: + raise AssertionError("causal mask not supported by AO MHA module") + + if self.batch_first: + query, key, value = (x.transpose(0, 1) for x in (query, key, value)) + + tgt_len, bsz, embed_dim_to_check = query.size() + assert self.embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = self.embed_dim // self.num_heads + assert head_dim * self.num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) + scaling = float(head_dim) ** -0.5 + + q = self.linear_Q(query) + k = self.linear_K(key) + v = self.linear_V(value) + + q = self.q_scaling_product.mul_scalar(q, scaling) + + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. " + "Use bool tensor instead.", + stacklevel=3, + ) + attn_mask = attn_mask.to(torch.bool) + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, ( + f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}" + ) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * self.num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. " + "Use bool tensor instead.", + stacklevel=3, + ) + key_padding_mask = key_padding_mask.to(torch.bool) + if self.bias_k is not None and self.bias_v is not None: + if static_k is None and static_v is None: + # Explicitly assert that bias_k and bias_v are not None + # in a way that TorchScript can understand. + bias_k = self.bias_k + assert bias_k is not None + bias_v = self.bias_v + assert bias_v is not None + + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + else: + assert self.bias_k is None + assert self.bias_v is None + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * self.num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * self.num_heads + assert static_v.size(2) == head_dim + v = static_v + + # pyrefly: ignore [missing-attribute] + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + # pyrefly: ignore [missing-attribute] + k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:]) + # pyrefly: ignore [missing-attribute] + if k.is_quantized: + k_zeros = torch.quantize_per_tensor( + k_zeros, + # pyrefly: ignore [missing-attribute] + k.q_scale(), + # pyrefly: ignore [missing-attribute] + k.q_zero_point(), + # pyrefly: ignore [missing-attribute] + k.dtype, + ) + # pyrefly: ignore [no-matching-overload] + k = torch.cat([k, k_zeros], dim=1) + # pyrefly: ignore [missing-attribute] + v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:]) + # pyrefly: ignore [missing-attribute] + if v.is_quantized: + v_zeros = torch.quantize_per_tensor( + v_zeros, + # pyrefly: ignore [missing-attribute] + v.q_scale(), + # pyrefly: ignore [missing-attribute] + v.q_zero_point(), + # pyrefly: ignore [missing-attribute] + v.dtype, + ) + # pyrefly: ignore [no-matching-overload] + v = torch.cat([v, v_zeros], dim=1) + + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # Leaving the quantized zone here + q = self.dequant_q(q) + k = self.dequant_k(k) + v = self.dequant_v(v) + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [ + bsz * self.num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * self.num_heads, tgt_len, src_len + ) + + attn_output_weights = F.softmax(attn_output_weights, dim=-1) + attn_output_weights = F.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim] + if self.batch_first: + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + else: + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, self.embed_dim) + ) + + # Reentering the quantized zone + attn_output = self.quant_attn_output(attn_output) + # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + attn_output = self.out_proj(attn_output) # type: ignore[has-type] + attn_output_weights = self.quant_attn_output_weights(attn_output_weights) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + return attn_output, attn_output_weights + else: + return attn_output, None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..74e4bd902d1565360f72a5c4098b6e6d1590a146 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/rnn.py @@ -0,0 +1,604 @@ +""" +We will recreate all the RNN modules as we require the modules to be decomposed +into its building blocks to be able to observe. +""" + +# mypy: allow-untyped-defs + +import numbers +import warnings + +import torch +from torch import Tensor + + +__all__ = ["LSTMCell", "LSTM"] + + +class LSTMCell(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM) cell. + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` + + `split_gates`: specify True to compute the input/forget/cell/output gates separately + to avoid an intermediate tensor which is subsequently chunk'd. This optimization can + be beneficial for on-device inference latency. This flag is cascaded down from the + parent classes. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> rnn = nnqa.LSTMCell(10, 20) + >>> input = torch.randn(6, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + """ + + _FLOAT_MODULE = torch.nn.LSTMCell + __constants__ = ["split_gates"] # for jit.script + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_dim + self.hidden_size = hidden_dim + self.bias = bias + self.split_gates = split_gates + + if not split_gates: + self.igates: torch.nn.Module = torch.nn.Linear( + input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs + ) + self.hgates: torch.nn.Module = torch.nn.Linear( + hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs + ) + self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional() + else: + # keep separate Linear layers for each gate + self.igates = torch.nn.ModuleDict() + self.hgates = torch.nn.ModuleDict() + self.gates = torch.nn.ModuleDict() + for g in ["input", "forget", "cell", "output"]: + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.igates[g] = torch.nn.Linear( + input_dim, hidden_dim, bias=bias, **factory_kwargs + ) + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.hgates[g] = torch.nn.Linear( + hidden_dim, hidden_dim, bias=bias, **factory_kwargs + ) + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.gates[g] = torch.ao.nn.quantized.FloatFunctional() + + self.input_gate = torch.nn.Sigmoid() + self.forget_gate = torch.nn.Sigmoid() + self.cell_gate = torch.nn.Tanh() + self.output_gate = torch.nn.Sigmoid() + + self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() + self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() + self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() + + self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() + + self.initial_hidden_state_qparams: tuple[float, int] = (1.0, 0) + self.initial_cell_state_qparams: tuple[float, int] = (1.0, 0) + self.hidden_state_dtype: torch.dtype = torch.quint8 + self.cell_state_dtype: torch.dtype = torch.quint8 + + def forward( + self, x: Tensor, hidden: tuple[Tensor, Tensor] | None = None + ) -> tuple[Tensor, Tensor]: + if hidden is None or hidden[0] is None or hidden[1] is None: + hidden = self.initialize_hidden(x.shape[0], x.is_quantized) + hx, cx = hidden + + if not self.split_gates: + igates = self.igates(x) + hgates = self.hgates(hx) + gates = self.gates.add(igates, hgates) # type: ignore[operator] + + input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) + + input_gate = self.input_gate(input_gate) + forget_gate = self.forget_gate(forget_gate) + cell_gate = self.cell_gate(cell_gate) + out_gate = self.output_gate(out_gate) + else: + # apply each input + hidden projection and add together + gate = {} + for (key, gates), igates, hgates in zip( + self.gates.items(), # type: ignore[operator] + self.igates.values(), # type: ignore[operator] + self.hgates.values(), # type: ignore[operator] + ): + gate[key] = gates.add(igates(x), hgates(hx)) + + input_gate = self.input_gate(gate["input"]) + forget_gate = self.forget_gate(gate["forget"]) + cell_gate = self.cell_gate(gate["cell"]) + out_gate = self.output_gate(gate["output"]) + + fgate_cx = self.fgate_cx.mul(forget_gate, cx) + igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) + fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) + cy = fgate_cx_igate_cgate + + # TODO: make this tanh a member of the module so its qparams can be configured + tanh_cy = torch.tanh(cy) + hy = self.ogate_cy.mul(out_gate, tanh_cy) + return hy, cy + + def initialize_hidden( + self, batch_size: int, is_quantized: bool = False + ) -> tuple[Tensor, Tensor]: + h, c = ( + torch.zeros((batch_size, self.hidden_size)), + torch.zeros((batch_size, self.hidden_size)), + ) + if is_quantized: + (h_scale, h_zp) = self.initial_hidden_state_qparams + (c_scale, c_zp) = self.initial_cell_state_qparams + h = torch.quantize_per_tensor( + h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype + ) + c = torch.quantize_per_tensor( + c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype + ) + return h, c + + def _get_name(self): + return "QuantizableLSTMCell" + + @classmethod + def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False): + """Uses the weights and biases to create a new LSTM cell. + + Args: + wi, wh: Weights for the input and hidden layers + bi, bh: Biases for the input and hidden layers + """ + assert (bi is None) == (bh is None) # Either both None or both have values + input_size = wi.shape[1] + hidden_size = wh.shape[1] + cell = cls( + input_dim=input_size, + hidden_dim=hidden_size, + bias=(bi is not None), + split_gates=split_gates, + ) + + if not split_gates: + cell.igates.weight = torch.nn.Parameter(wi) + if bi is not None: + cell.igates.bias = torch.nn.Parameter(bi) + cell.hgates.weight = torch.nn.Parameter(wh) + if bh is not None: + cell.hgates.bias = torch.nn.Parameter(bh) + else: + # split weight/bias + for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]): + for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()): # type: ignore[operator] + gate.weight = torch.nn.Parameter(w_chunk) + + if b is not None: + for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()): # type: ignore[operator] + gate.bias = torch.nn.Parameter(b_chunk) + + return cell + + @classmethod + def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False): + assert type(other) is cls._FLOAT_MODULE + assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" + observed = cls.from_params( + other.weight_ih, + other.weight_hh, + other.bias_ih, + other.bias_hh, + split_gates=split_gates, + ) + observed.qconfig = other.qconfig + observed.igates.qconfig = other.qconfig + observed.hgates.qconfig = other.qconfig + if split_gates: + # also apply qconfig directly to Linear modules + for g in observed.igates.values(): + g.qconfig = other.qconfig + for g in observed.hgates.values(): + g.qconfig = other.qconfig + return observed + + +class _LSTMSingleLayer(torch.nn.Module): + r"""A single one-directional LSTM layer. + + The difference between a layer and a cell is that the layer can process a + sequence, while the cell only expects an instantaneous value. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.cell = LSTMCell( + input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs + ) + + def forward(self, x: Tensor, hidden: tuple[Tensor, Tensor] | None = None): + result = [] + seq_len = x.shape[0] + for i in range(seq_len): + hidden = self.cell(x[i], hidden) + result.append(hidden[0]) # type: ignore[index] + result_tensor = torch.stack(result, 0) + return result_tensor, hidden + + @classmethod + def from_params(cls, *args, **kwargs): + cell = LSTMCell.from_params(*args, **kwargs) + layer = cls( + cell.input_size, cell.hidden_size, cell.bias, split_gates=cell.split_gates + ) + layer.cell = cell + return layer + + +class _LSTMLayer(torch.nn.Module): + r"""A single bi-directional LSTM layer.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + batch_first: bool = False, + bidirectional: bool = False, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.batch_first = batch_first + self.bidirectional = bidirectional + self.layer_fw = _LSTMSingleLayer( + input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs + ) + if self.bidirectional: + self.layer_bw = _LSTMSingleLayer( + input_dim, + hidden_dim, + bias=bias, + split_gates=split_gates, + **factory_kwargs, + ) + + def forward(self, x: Tensor, hidden: tuple[Tensor, Tensor] | None = None): + if self.batch_first: + x = x.transpose(0, 1) + if hidden is None: + hx_fw, cx_fw = (None, None) + else: + hx_fw, cx_fw = hidden + hidden_bw: tuple[Tensor, Tensor] | None = None + if self.bidirectional: + if hx_fw is None: + hx_bw = None + else: + hx_bw = hx_fw[1] + hx_fw = hx_fw[0] + if cx_fw is None: + cx_bw = None + else: + cx_bw = cx_fw[1] + cx_fw = cx_fw[0] + if hx_bw is not None and cx_bw is not None: + hidden_bw = hx_bw, cx_bw + if hx_fw is None and cx_fw is None: + hidden_fw = None + else: + hidden_fw = ( + torch.jit._unwrap_optional(hx_fw), + torch.jit._unwrap_optional(cx_fw), + ) + result_fw, hidden_fw = self.layer_fw(x, hidden_fw) + + if hasattr(self, "layer_bw") and self.bidirectional: + x_reversed = x.flip(0) + result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) + result_bw = result_bw.flip(0) + + result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) + if hidden_fw is None and hidden_bw is None: + h = None + c = None + elif hidden_fw is None: + (h, c) = torch.jit._unwrap_optional(hidden_bw) + elif hidden_bw is None: + (h, c) = torch.jit._unwrap_optional(hidden_fw) + else: + h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item] + c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item] + else: + result = result_fw + h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment] + + if self.batch_first: + result.transpose_(0, 1) + + return result, (h, c) + + @classmethod + def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs): + r""" + There is no FP equivalent of this class. This function is here just to + mimic the behavior of the `prepare` within the `torch.ao.quantization` + flow. + """ + assert hasattr(other, "qconfig") or (qconfig is not None) + + input_size = kwargs.get("input_size", other.input_size) + hidden_size = kwargs.get("hidden_size", other.hidden_size) + bias = kwargs.get("bias", other.bias) + batch_first = kwargs.get("batch_first", other.batch_first) + bidirectional = kwargs.get("bidirectional", other.bidirectional) + split_gates = kwargs.get("split_gates", False) + + layer = cls( + input_size, + hidden_size, + bias, + batch_first, + bidirectional, + split_gates=split_gates, + ) + # pyrefly: ignore [bad-argument-type] + layer.qconfig = getattr(other, "qconfig", qconfig) + wi = getattr(other, f"weight_ih_l{layer_idx}") + wh = getattr(other, f"weight_hh_l{layer_idx}") + bi = getattr(other, f"bias_ih_l{layer_idx}", None) + bh = getattr(other, f"bias_hh_l{layer_idx}", None) + + layer.layer_fw = _LSTMSingleLayer.from_params( + wi, wh, bi, bh, split_gates=split_gates + ) + + if other.bidirectional: + wi = getattr(other, f"weight_ih_l{layer_idx}_reverse") + wh = getattr(other, f"weight_hh_l{layer_idx}_reverse") + bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None) + bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None) + layer.layer_bw = _LSTMSingleLayer.from_params( + wi, wh, bi, bh, split_gates=split_gates + ) + return layer + + +class LSTM(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples below. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> rnn = nnqa.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + >>> # To get the weights: + >>> # xdoctest: +SKIP + >>> print(rnn.layers[0].weight_ih) + tensor([[...]]) + >>> print(rnn.layers[0].weight_hh) + AssertionError: There is no reverse path in the non-bidirectional layer + """ + + _FLOAT_MODULE = torch.nn.LSTM + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + *, + split_gates: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.training = False # Default to eval mode. If we want to train, we will explicitly set to training. + + if ( + not isinstance(dropout, numbers.Number) + # pyrefly: ignore [unsupported-operation] + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + # pyrefly: ignore [unsupported-operation] + if dropout > 0: + warnings.warn( + "dropout option for quantizable LSTM is ignored. " + "If you are training, please, use nn.LSTM version " + "followed by `prepare` step.", + stacklevel=2, + ) + if num_layers == 1: + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} " + f"and num_layers={num_layers}", + stacklevel=2, + ) + + layers = [ + _LSTMLayer( + self.input_size, + self.hidden_size, + self.bias, + batch_first=False, + bidirectional=self.bidirectional, + split_gates=split_gates, + **factory_kwargs, + ) + ] + layers.extend( + _LSTMLayer( + self.hidden_size, + self.hidden_size, + self.bias, + batch_first=False, + bidirectional=self.bidirectional, + split_gates=split_gates, + **factory_kwargs, + ) + for _ in range(1, num_layers) + ) + self.layers = torch.nn.ModuleList(layers) + + def forward(self, x: Tensor, hidden: tuple[Tensor, Tensor] | None = None): + if self.batch_first: + x = x.transpose(0, 1) + + max_batch_size = x.size(1) + num_directions = 2 if self.bidirectional else 1 + if hidden is None: + zeros = torch.zeros( + num_directions, + max_batch_size, + self.hidden_size, + dtype=torch.float, + device=x.device, + ) + zeros.squeeze_(0) + if x.is_quantized: + zeros = torch.quantize_per_tensor( + zeros, scale=1.0, zero_point=0, dtype=x.dtype + ) + hxcx = [(zeros, zeros) for _ in range(self.num_layers)] + else: + hidden_non_opt = torch.jit._unwrap_optional(hidden) + if isinstance(hidden_non_opt[0], Tensor): + hx = hidden_non_opt[0].reshape( + self.num_layers, num_directions, max_batch_size, self.hidden_size + ) + cx = hidden_non_opt[1].reshape( + self.num_layers, num_directions, max_batch_size, self.hidden_size + ) + hxcx = [ + (hx[idx].squeeze(0), cx[idx].squeeze(0)) + for idx in range(self.num_layers) + ] + else: + hxcx = hidden_non_opt + + hx_list = [] + cx_list = [] + for idx, layer in enumerate(self.layers): + x, (h, c) = layer(x, hxcx[idx]) + hx_list.append(torch.jit._unwrap_optional(h)) + cx_list.append(torch.jit._unwrap_optional(c)) + hx_tensor = torch.stack(hx_list) + cx_tensor = torch.stack(cx_list) + + # We are creating another dimension for bidirectional case + # need to collapse it + hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1]) + cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1]) + + if self.batch_first: + x = x.transpose(0, 1) + + return x, (hx_tensor, cx_tensor) + + def _get_name(self): + return "QuantizableLSTM" + + @classmethod + def from_float(cls, other, qconfig=None, split_gates=False): + assert isinstance(other, cls._FLOAT_MODULE) + assert hasattr(other, "qconfig") or qconfig + observed = cls( + other.input_size, + other.hidden_size, + other.num_layers, + other.bias, + other.batch_first, + other.dropout, + other.bidirectional, + split_gates=split_gates, + ) + # pyrefly: ignore [bad-argument-type] + observed.qconfig = getattr(other, "qconfig", qconfig) + for idx in range(other.num_layers): + observed.layers[idx] = _LSTMLayer.from_float( + other, idx, qconfig, batch_first=False, split_gates=split_gates + ) + + # Prepare the model + if other.training: + observed.train() + observed = torch.ao.quantization.prepare_qat(observed, inplace=True) + else: + observed.eval() + observed = torch.ao.quantization.prepare(observed, inplace=True) + return observed + + @classmethod + def from_observed(cls, other): + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-quantizable LSTM module. Please, see " + "the examples on quantizable LSTMs." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77e97d8595282f3d69963ee129fa473249e3ae29 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__init__.py @@ -0,0 +1,39 @@ +from . import functional +from .modules import * # noqa: F403 +from .modules import MaxPool2d + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7210ee744833644080b38f610d182e28b773c204 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d827a52be10268d4fe0959bd2f7809b7e48173a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..806bfd8c83ca88b1bf5765d7201afdd91edce57d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..969fd6f121f5ddb72ed2e8e158e3ee7e990cfd0c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py @@ -0,0 +1,26 @@ +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .linear import Linear +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell + + +__all__ = [ + "Linear", + "LSTM", + "GRU", + "LSTMCell", + "RNNCell", + "GRUCell", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89693f18aa8dafeba6ffb76c78ce06906883372a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3010eb0b1f8f22a1ebdfbc8cd977c2157b479642 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aa50cf9a4918e9f31245063fa154db4751130e5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d67ed3494cdcf08f180146707bedd025b22b1d7c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..68c3f6acd093477a44057ade1fb48107709eda89 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -0,0 +1,530 @@ +# mypy: allow-untyped-defs +r"""Dynamically quantized convolution modules.""" + +import warnings +from typing import ClassVar, Literal + +import torch +import torch.ao.nn.quantized as nnq +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch._ops import ops +from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding +from torch.nn.common_types import _size_1_t +from torch.nn.modules.utils import _pair, _single, _triple + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class Conv1d(nnq.Conv1d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.dynamic.Conv1d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv1d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 100) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[type[nn.Module] | None] = None + _NNI_CONV_RELU_MODULE: ClassVar[type[nn.Module] | None] = None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + reduce_range=True, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + stacklevel=2, + ) + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + # pyrefly: ignore [bad-assignment] + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv1d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range) + + +class Conv2d(nnq.Conv2d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.dynamic.Conv2d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv2d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[type[nn.Module] | None] = None + _NNI_CONV_RELU_MODULE: ClassVar[type[nn.Module] | None] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module " + "has poor numerical accuracy and its use is not recommended", + stacklevel=2, + ) + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv2d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv2d_dynamic(input, self._packed_params, reduce_range) + + +class Conv3d(nnq.Conv3d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv3d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) + >>> input = torch.randn(20, 16, 56, 56, 56) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[type[nn.Module] | None] = None + _NNI_CONV_RELU_MODULE: ClassVar[type[nn.Module] | None] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + stacklevel=2, + ) + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv3d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv3d_dynamic(input, self._packed_params, reduce_range) + + +class ConvTranspose1d(nnq.ConvTranspose1d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose1d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv1d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose1d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + stacklevel=2, + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose1d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + return torch.ops.quantized.conv_transpose1d_dynamic( + input, self._packed_params, reduce_range + ) + + +class ConvTranspose2d(nnq.ConvTranspose2d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose2d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv2d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + stacklevel=2, + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose2d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return ops.quantized.conv_transpose2d_dynamic( + input, self._packed_params, reduce_range + ) + + +class ConvTranspose3d(nnq.ConvTranspose3d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + stacklevel=2, + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose3d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d_dynamic( + input, self._packed_params, reduce_range + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..523ff78c31cf141e680e0a3374bcb5f1252cf7d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -0,0 +1,168 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized as nnq +from torch.ao.nn.quantized.modules.utils import _quantize_weight + + +__all__ = [ + "Linear", +] + + +class Linear(nnq.Linear): + r""" + A dynamic quantized linear module with floating point tensor as inputs and outputs. + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. + + Similar to :class:`torch.nn.Linear`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module which are of + shape :math:`(\text{out\_features}, \text{in\_features})`. + bias (Tensor): the non-learnable floating point bias of the module of shape + :math:`(\text{out\_features})`. If :attr:`bias` is ``True``, + the values are initialized to zero. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.quantized.dynamic.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + # version used in this class is different from the parent class nnq.Linear + _version = 4 + + def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias_, dtype=dtype) + # We don't muck around with buffers or attributes or anything here + # to keep the module simple. *everything* is simply a Python attribute. + # Serialization logic is explicitly handled in the below serialization and + # deserialization modules + self.version = 4 + + def forward(self, x): + # Note that we can handle self.bias == None case. + if self._packed_params.dtype == torch.qint8: + if self.version is None or self.version < 4: + Y = torch.ops.quantized.linear_dynamic( + x, self._packed_params._packed_params + ) + else: + Y = torch.ops.quantized.linear_dynamic( + x, self._packed_params._packed_params, reduce_range=True + ) + elif self._packed_params.dtype == torch.float16: + Y = torch.ops.quantized.linear_dynamic_fp16( + x, self._packed_params._packed_params + ) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + return Y.to(x.dtype) + + def _get_name(self): + return "DynamicQuantizedLinear" + + def extra_repr(self): + extra_repr_str = f"in_features={self.in_features}, out_features={self.out_features}, dtype={self._packed_params.dtype}" + if self._packed_params.dtype == torch.qint8: + extra_repr_str += f", qscheme={self.weight().qscheme()}" + return extra_repr_str + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + self.version = version + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a dynamic quantized module from a float module or qparams_dict + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + float_modules = [ + torch.nn.Linear, + torch.nn.modules.linear.NonDynamicallyQuantizableLinear, + torch.ao.nn.intrinsic.modules.fused.LinearReLU, + torch.ao.nn.qat.dynamic.Linear, + ] + + assert type(mod) in float_modules, ( + "nn.quantized.dynamic.Linear.from_float only works for one of" + + str([float_mod.__name__ for float_mod in float_modules]) + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + if type(mod) is nni.LinearReLU: + mod = mod[0] + # pyrefly: ignore [missing-attribute] + if mod.qconfig is not None and mod.qconfig.weight is not None: + # pyrefly: ignore [not-callable] + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer = default_dynamic_qconfig.weight() + dtype = weight_observer.dtype + assert dtype in [torch.qint8, torch.float16], ( + "The only supported dtypes for " + f"dynamic quantized linear are qint8 and float16 got: {dtype}" + ) + weight_observer(mod.weight) + if dtype == torch.qint8: + qweight = _quantize_weight(mod.weight.float(), weight_observer) + elif dtype == torch.float16: + qweight = mod.weight.float() + else: + raise RuntimeError( + "Unsupported dtype specified for dynamic quantized Linear!" + ) + qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + # pyrefly: ignore [bad-argument-type] + qlinear.set_weight_bias(qweight, mod.bias) + return qlinear + + @classmethod + def from_reference(cls, ref_qlinear): # type: ignore[override] + """Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized + module + Args: + ref_qlinear (Module): a reference quantized module, either produced by + torch.ao.quantization functions or provided by the user + """ + qlinear = cls( + ref_qlinear.in_features, + ref_qlinear.out_features, + dtype=ref_qlinear.weight_dtype, + ) + qweight = ref_qlinear.get_quantized_weight() + bias = ref_qlinear.bias + qlinear.set_weight_bias(qweight, bias) + return qlinear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebe4b6a15af499f38a0d70ca93870cf1d6c224f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -0,0 +1,1366 @@ +# mypy: allow-untyped-defs +import numbers +import warnings +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch import Tensor # noqa: F401 +from torch._jit_internal import Dict, List, Optional, Tuple, Union # noqa: F401 +from torch.ao.nn.quantized.modules.utils import _quantize_weight +from torch.nn.utils.rnn import PackedSequence + + +__all__ = [ + "pack_weight_bias", + "PackedParameter", + "RNNBase", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", + "apply_permutation", +] + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return _apply_permutation(tensor, permutation, dim) + + +def pack_weight_bias(qweight, bias, dtype): + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + packed_weight = torch.ops.quantized.linear_prepack(qweight, bias) + + return packed_weight + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight = torch.ops.quantized.linear_prepack_fp16(qweight, bias) + + return packed_weight + + +class PackedParameter(torch.nn.Module): + def __init__(self, param): + super().__init__() + self.param = param + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "param"] = self.param + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.param = state_dict[prefix + "param"] + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class RNNBase(torch.nn.Module): + _FLOAT_MODULE = nn.RNNBase + + _version = 2 + + def __init__( + self, + mode, + input_size, + hidden_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + dtype=torch.qint8, + ): + super().__init__() + + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.dtype = dtype + self.version = 2 + self.training = False + num_directions = 2 if bidirectional else 1 + + # "type: ignore" is required since ints and Numbers are not fully comparable + # https://github.com/python/mypy/issues/8566 + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 # type: ignore[operator] + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0 and num_layers == 1: # type: ignore[operator] + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} and " + f"num_layers={num_layers}", + stacklevel=2, + ) + + if mode == "LSTM": + gate_size = 4 * hidden_size + elif mode == "GRU": + gate_size = 3 * hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + _all_weight_values = [] + for layer in range(num_layers): + for _ in range(num_directions): + layer_input_size = ( + input_size if layer == 0 else hidden_size * num_directions + ) + + w_ih = torch.randn(gate_size, layer_input_size).to(torch.float) + w_hh = torch.randn(gate_size, hidden_size).to(torch.float) + b_ih = torch.randn(gate_size).to(torch.float) + b_hh = torch.randn(gate_size).to(torch.float) + if dtype == torch.qint8: + w_ih = torch.quantize_per_tensor( + w_ih, scale=0.1, zero_point=0, dtype=torch.qint8 + ) + w_hh = torch.quantize_per_tensor( + w_hh, scale=0.1, zero_point=0, dtype=torch.qint8 + ) + packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh) + if self.version is None or self.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh, True + ) + ) + else: + packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh) + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + + _all_weight_values.append(PackedParameter(cell_params)) + self._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + def _get_name(self): + return "DynamicQuantizedRNN" + + def extra_repr(self): + s = "{input_size}, {hidden_size}" + if self.num_layers != 1: + s += ", num_layers={num_layers}" + if self.bias is not True: + s += ", bias={bias}" + if self.batch_first is not False: + s += ", batch_first={batch_first}" + if self.dropout != 0: + s += ", dropout={dropout}" + if self.bidirectional is not False: + s += ", bidirectional={bidirectional}" + return s.format(**self.__dict__) + + def __repr__(self): + # We don't want to show `ModuleList` children, hence custom + # `__repr__`. This is the same as nn.Module.__repr__, except the check + # for the `PackedParameter` and `nn.ModuleList`. + # You should still override `extra_repr` to add more info. + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + if isinstance(module, (PackedParameter, nn.ModuleList)): + continue + mod_str = repr(module) + mod_str = nn.modules.module._addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) + if self.input_size != input.size(-1): + raise RuntimeError( + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) + + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + self.check_hidden_size( + hidden, expected_hidden_size, msg="Expected hidden size {}, got {}" + ) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + self.version = version + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def set_weight_bias(self, weight_bias_dict): + def weight_bias_name(ihhh, layer, suffix): + weight_name = f"weight_{ihhh}_l{layer}{suffix}" + bias_name = f"bias_{ihhh}_l{layer}{suffix}" + return weight_name, bias_name + + num_directions = 2 if self.bidirectional else 1 + # TODO: dedup with __init__ of RNNBase + _all_weight_values = [] + for layer in range(self.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix) + w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix) + w_ih = weight_bias_dict[w_ih_name] + b_ih = weight_bias_dict[b_ih_name] + w_hh = weight_bias_dict[w_hh_name] + b_hh = weight_bias_dict[b_hh_name] + if w_ih.dtype == torch.qint8: + packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh) + if self.version is None or self.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh, True + ) + ) + else: + packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh) + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + + _all_weight_values.append(PackedParameter(cell_params)) + self._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) in { + torch.nn.LSTM, + torch.nn.GRU, + }, "nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU" + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer_method = mod.qconfig.weight + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer_method = default_dynamic_qconfig.weight + + dtype = weight_observer_method().dtype + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError( + f"Unsupported dtype for dynamic RNN quantization: {dtype}" + ) + # RNNBase can be either LSTM or GRU + qRNNBase: Union[LSTM, GRU] + if mod.mode == "LSTM": + qRNNBase = LSTM( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + dtype, + ) + elif mod.mode == "GRU": + qRNNBase = GRU( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + dtype, + ) + else: + raise NotImplementedError( + "Only LSTM/GRU is supported for QuantizedRNN for now" + ) + + num_directions = 2 if mod.bidirectional else 1 + + assert mod.bias + + _all_weight_values = [] + for layer in range(qRNNBase.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + + def retrieve_weight_bias(ihhh): + weight_name = f"weight_{ihhh}_l{layer}{suffix}" + bias_name = f"bias_{ihhh}_l{layer}{suffix}" + weight = getattr(mod, weight_name) + bias = getattr(mod, bias_name) + return weight, bias + + weight_ih, bias_ih = retrieve_weight_bias("ih") + weight_hh, bias_hh = retrieve_weight_bias("hh") + + if dtype == torch.qint8: + + def quantize_and_pack(w, b): + weight_observer = weight_observer_method() + weight_observer(w) + qweight = _quantize_weight(w.float(), weight_observer) + packed_weight = torch.ops.quantized.linear_prepack(qweight, b) + return packed_weight + + packed_ih = quantize_and_pack(weight_ih, bias_ih) + packed_hh = quantize_and_pack(weight_hh, bias_hh) + if qRNNBase.version is None or qRNNBase.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, bias_ih, bias_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, bias_ih, bias_hh, True + ) + ) + + elif dtype == torch.float16: + packed_ih = torch.ops.quantized.linear_prepack_fp16( + weight_ih.float(), bias_ih + ) + packed_hh = torch.ops.quantized.linear_prepack_fp16( + weight_hh.float(), bias_hh + ) + + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + else: + raise RuntimeError( + "Unsupported dtype specified for dynamic quantized LSTM!" + ) + + _all_weight_values.append(PackedParameter(cell_params)) + qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + return qRNNBase + + def _weight_bias(self): + # Returns a dict of weights and biases + weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}} + count = 0 + num_directions = 2 if self.bidirectional else 1 + for layer in range(self.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + key_name1 = f"weight_ih_l{layer}{suffix}" + key_name2 = f"weight_hh_l{layer}{suffix}" + # packed weights are part of torchbind class, CellParamsSerializationType + # Within the packed weight class, the weight and bias are accessible as Tensors + packed_weight_bias = self._all_weight_values[ # type: ignore[index] + count + ].param.__getstate__()[0][4] + weight_bias_dict["weight"][key_name1] = packed_weight_bias[ + 0 + ].__getstate__()[0][0] + weight_bias_dict["weight"][key_name2] = packed_weight_bias[ + 1 + ].__getstate__()[0][0] + key_name1 = f"bias_ih_l{layer}{suffix}" + key_name2 = f"bias_hh_l{layer}{suffix}" + weight_bias_dict["bias"][key_name1] = packed_weight_bias[ + 0 + ].__getstate__()[0][1] + weight_bias_dict["bias"][key_name2] = packed_weight_bias[ + 1 + ].__getstate__()[0][1] + count = count + 1 + return weight_bias_dict + + def get_weight(self): + return self._weight_bias()["weight"] + + def get_bias(self): + return self._weight_bias()["bias"] + + +class LSTM(RNNBase): + r""" + A dynamic quantized LSTM module with floating point tensor as inputs and outputs. + We adopt the same interface as `torch.nn.LSTM`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + """ + + # pyrefly: ignore [bad-override] + _FLOAT_MODULE = nn.LSTM + + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + def _get_name(self): + return "DynamicQuantizedLSTM" + + def forward_impl( + self, + input: Tensor, + hx: Optional[tuple[Tensor, Tensor]], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (zeros, zeros) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + + _all_params = [m.param for m in self._all_weight_values] + if batch_sizes is None: + result = torch.quantized_lstm( + input, + hx, + _all_params, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + self.batch_first, + dtype=self.dtype, + use_dynamic=True, + ) + else: + result = torch.quantized_lstm( + input, + batch_sizes, + hx, + _all_params, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + dtype=self.dtype, + use_dynamic=True, + ) + output = result[0] + hidden = result[1:] + + return output, hidden + + @torch.jit.export + def forward_tensor( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) + + return output, self.permute_hidden(hidden, unsorted_indices) + + @torch.jit.export + def forward_packed( + self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: + input_, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + + output_, hidden = self.forward_impl( + input_, hx, batch_sizes, max_batch_size, sorted_indices + ) + + output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + + # "type: ignore" is required due to issue #43072 + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + # "type: ignore" is required due to issue #43072 + def check_forward_args( # type: ignore[override] + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size( + hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}" + ) + self.check_hidden_size( + hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}" + ) + + @torch.jit.ignore + def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 " + "exists in LSTM, may need to relax the assumption to support the use case" + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.num_layers, + ref_mod.bias, + ref_mod.batch_first, + ref_mod.dropout, + ref_mod.bidirectional, + # assuming there is layer 0, which should be OK + ref_mod.weight_ih_l0_dtype, + ) + qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict()) + return qmod + + +class GRU(RNNBase): + r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + + + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features + of the input sequence. The input can also be a packed variable length + sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` + for details. + - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the initial hidden state for each element in the batch. + Defaults to zero if not provided. If the RNN is bidirectional, + num_directions should be 2, else it should be 1. + + Outputs: output, h_n + - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor + containing the output features h_t from the last layer of the GRU, + for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been + given as the input, the output will also be a packed sequence. + For the unpacked case, the directions can be separated + using ``output.view(seq_len, batch, num_directions, hidden_size)``, + with forward and backward being direction `0` and `1` respectively. + + Similarly, the directions can be separated in the packed case. + - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the hidden state for `t = seq_len` + + Like *output*, the layers can be separated using + ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + + Shape: + - Input1: :math:`(L, N, H_{in})` tensor containing input features where + :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. + - Input2: :math:`(S, N, H_{out})` tensor + containing the initial hidden state for each element in the batch. + :math:`H_{out}=\text{hidden\_size}` + Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` + If the RNN is bidirectional, num_directions should be 2, else it should be 1. + - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` + - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state + for each element in the batch + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. + In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the + previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix + `W` and addition of bias: + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) + \end{aligned} + + This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}` + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) + \end{aligned} + + This implementation differs on purpose for efficiency. + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + # pyrefly: ignore [bad-override] + _FLOAT_MODULE = nn.GRU + + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} + + def __init__(self, *args, **kwargs): + super().__init__("GRU", *args, **kwargs) + + def _get_name(self): + return "DynamicQuantizedGRU" + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size( + hidden, expected_hidden_size, "Expected hidden size {}, got {}" + ) + + def forward_impl( + self, + input: Tensor, + hx: Optional[Tensor], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = zeros + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + + _all_params = [m.param for m in self._all_weight_values] + if batch_sizes is None: + result = torch.quantized_gru( + input, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = torch.quantized_gru( + input, + batch_sizes, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + return output, hidden + + @torch.jit.export + def forward_tensor( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) + + return output, self.permute_hidden(hidden, unsorted_indices) + + @torch.jit.export + def forward_packed( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> tuple[PackedSequence, Tensor]: + input_, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + output_, hidden = self.forward_impl( + input_, hx, batch_sizes, max_batch_size, sorted_indices + ) + + output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + @torch.jit.ignore + def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 " + "exists in LSTM, may need to relax the assumption to support the use case" + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.num_layers, + ref_mod.bias, + ref_mod.batch_first, + ref_mod.dropout, + ref_mod.bidirectional, + # assuming there is layer 0, which should be OK + ref_mod.weight_ih_l0_dtype, + ) + qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict()) + return qmod + + +class RNNCellBase(torch.nn.Module): + # _FLOAT_MODULE = nn.CellRNNBase + __constants__ = ["input_size", "hidden_size", "bias"] + + def __init__( + self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8 + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_dtype = dtype + if bias: + self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float) + self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float) + weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float) + if dtype == torch.qint8: + weight_ih = torch.quantize_per_tensor( + weight_ih, scale=1, zero_point=0, dtype=torch.qint8 + ) + weight_hh = torch.quantize_per_tensor( + weight_hh, scale=1, zero_point=0, dtype=torch.qint8 + ) + + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + packed_weight_ih = torch.ops.quantized.linear_prepack( + weight_ih, self.bias_ih + ) + packed_weight_hh = torch.ops.quantized.linear_prepack( + weight_hh, self.bias_hh + ) + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight_ih = torch.ops.quantized.linear_prepack_fp16( + weight_ih, self.bias_ih + ) + packed_weight_hh = torch.ops.quantized.linear_prepack_fp16( + weight_hh, self.bias_hh + ) + + self._packed_weight_ih = packed_weight_ih + self._packed_weight_hh = packed_weight_hh + + def _get_name(self): + return "DynamicQuantizedRNNBase" + + def extra_repr(self): + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__) + + def check_forward_input(self, input): + if input.size(1) != self.input_size: + raise RuntimeError( + f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}" + ) + + def check_forward_hidden( + self, input: Tensor, hx: Tensor, hidden_label: str = "" + ) -> None: + if input.size(0) != hx.size(0): + raise RuntimeError( + f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) + + if hx.size(1) != self.hidden_size: + raise RuntimeError( + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) in { + torch.nn.LSTMCell, + torch.nn.GRUCell, + torch.nn.RNNCell, + }, ( + "nn.quantized.dynamic.RNNCellBase.from_float \ + only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell" + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer_method = mod.qconfig.weight + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer_method = default_dynamic_qconfig.weight + + dtype = weight_observer_method().dtype + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError( + f"Unsupported dtype for dynamic RNN quantization: {dtype}" + ) + + qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell] + + if type(mod) is torch.nn.LSTMCell: + qRNNCellBase = LSTMCell( + mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype + ) + elif type(mod) is torch.nn.GRUCell: + qRNNCellBase = GRUCell( + mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype + ) + elif type(mod) is torch.nn.RNNCell: + qRNNCellBase = RNNCell( + mod.input_size, + mod.hidden_size, + bias=mod.bias, + nonlinearity=mod.nonlinearity, + dtype=dtype, + ) + else: + raise NotImplementedError( + "Only LSTMCell, GRUCell and RNNCell \ + are supported for QuantizedRNN for now" + ) + + assert mod.bias + + def _observe_and_quantize_weight(weight): + if dtype == torch.qint8: + weight_observer = weight_observer_method() + weight_observer(weight) + qweight = _quantize_weight(weight.float(), weight_observer) + return qweight + else: + return weight.float() + + qRNNCellBase._packed_weight_ih = pack_weight_bias( + _observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype + ) + qRNNCellBase._packed_weight_hh = pack_weight_bias( + _observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype + ) + return qRNNCellBase + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_dtype"), "We are assuming weight_ih " + "exists in reference module, may need to relax the assumption to support the use case" + if hasattr(ref_mod, "nonlinearity"): + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.bias, + ref_mod.nonlinearity, + dtype=ref_mod.weight_ih_dtype, + ) + else: + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.bias, + dtype=ref_mod.weight_ih_dtype, + ) + weight_bias_dict = { + "weight": { + "weight_ih": ref_mod.get_quantized_weight_ih(), + "weight_hh": ref_mod.get_quantized_weight_hh(), + }, + "bias": { + "bias_ih": ref_mod.bias_ih, + "bias_hh": ref_mod.bias_hh, + }, + } + qmod.set_weight_bias(weight_bias_dict) + return qmod + + def _weight_bias(self): + # Returns a dict of weights and biases + weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}} + w1, b1 = self._packed_weight_ih.__getstate__()[0] + w2, b2 = self._packed_weight_hh.__getstate__()[0] + # TODO: these can be simplified to one level? e.g. using weight_ih as key + # directly + weight_bias_dict["weight"]["weight_ih"] = w1 + weight_bias_dict["weight"]["weight_hh"] = w2 + weight_bias_dict["bias"]["bias_ih"] = b1 + weight_bias_dict["bias"]["bias_hh"] = b2 + return weight_bias_dict + + def get_weight(self): + return self._weight_bias()["weight"] + + def get_bias(self): + return self._weight_bias()["bias"] + + def set_weight_bias(self, weight_bias_dict): + # TODO: these can be simplified to one level? e.g. using weight_ih as key + # directly + self._packed_weight_ih = pack_weight_bias( + weight_bias_dict["weight"]["weight_ih"], + weight_bias_dict["bias"]["bias_ih"], + self.weight_dtype, + ) + self._packed_weight_hh = pack_weight_bias( + weight_bias_dict["weight"]["weight_hh"], + weight_bias_dict["bias"]["bias_hh"], + self.weight_dtype, + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "_packed_weight_ih"] = self._packed_weight_ih + destination[prefix + "_packed_weight_hh"] = self._packed_weight_hh + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self._packed_weight_ih = state_dict.pop(prefix + "_packed_weight_ih") + self._packed_weight_hh = state_dict.pop(prefix + "_packed_weight_hh") + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class RNNCell(RNNCellBase): + r"""An Elman RNN cell with tanh or ReLU non-linearity. + A dynamic quantized RNNCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.RNNCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] + + def __init__( + self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8 + ): + super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype) + self.nonlinearity = nonlinearity + + def _get_name(self): + return "DynamicQuantizedRNNCell" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + self.check_forward_input(input) + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") + if self.nonlinearity == "tanh": + ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + return ret + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell. + + A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.LSTMCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, num_chunks=4, **kwargs) # type: ignore[misc] + + def _get_name(self): + return "DynamicQuantizedLSTMCell" + + def forward( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + self.check_forward_input(input) + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + self.check_forward_hidden(input, hx[0], "[0]") + self.check_forward_hidden(input, hx[1], "[1]") + return torch.ops.quantized.quantized_lstm_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell + + A dynamic quantized GRUCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.GRUCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8): + super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype) + + def _get_name(self): + return "DynamicQuantizedGRUCell" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + self.check_forward_input(input) + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") + return torch.ops.quantized.quantized_gru_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/functional.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..f84d41b58503ad1d86244c7aa358f09ad16acad2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/functional.py @@ -0,0 +1,781 @@ +# mypy: allow-untyped-defs +r"""Functional interface (quantized).""" + +import warnings + +import torch +from torch import Tensor +from torch.jit.annotations import BroadcastingList2 +from torch.nn.modules.utils import _pair, _triple + +from .modules.utils import _pair_from_first + + +# Although some of the functions and docstrings are mirrored from the torch.nn, +# we want to have them here for future changes. + +__all__ = [ + "avg_pool2d", + "avg_pool3d", + "adaptive_avg_pool2d", + "adaptive_avg_pool3d", + "conv1d", + "conv2d", + "conv3d", + "interpolate", + "linear", + "max_pool1d", + "max_pool2d", + "celu", + "leaky_relu", + "hardtanh", + "hardswish", + "threshold", + "elu", + "hardsigmoid", + "clamp", + "upsample", + "upsample_bilinear", + "upsample_nearest", +] + + +def avg_pool2d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + r""" + Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size + :math:`sH \times sW` steps. The number of output features is equal to the number of + input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AvgPool2d` for details and output shape. + + Args: + input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!") + return torch.nn.functional.avg_pool2d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + +def avg_pool3d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + r""" + Applies 3D average-pooling operation in :math:`kD \ times kH \times kW` regions by step size + :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of + input planes. + + .. note:: The input quantization parameters propagate to the output. + + Args: + input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kD, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sD, sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padD, padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!") + return torch.nn.functional.avg_pool3d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r""" + Applies a 2D adaptive average pooling over a quantized input signal composed + of several quantized input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if not input.is_quantized: + raise ValueError( + "Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!" + ) + return torch.nn.functional.adaptive_avg_pool2d(input, output_size) + + +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r""" + Applies a 3D adaptive average pooling over a quantized input signal composed + of several quantized input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if not input.is_quantized: + raise ValueError( + "Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!" + ) + return torch.nn.functional.adaptive_avg_pool3d(input, output_size) + + +def conv1d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 1D convolution over a quantized 1D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv1d` for details and output shape. + + Args: + input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , iW)` + bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sW,)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padW,)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dW,)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(33, 16, 3, dtype=torch.float) + >>> inputs = torch.randn(20, 16, 50, dtype=torch.float) + >>> bias = torch.randn(33, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + stride = _pair_from_first(stride) + padding = _pair_from_first(padding) + dilation = _pair_from_first(dilation) + + packed_params = torch.ops.quantized.conv1d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point) + + +def conv2d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 2D convolution over a quantized 2D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape. + + Args: + input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` + bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float) + >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float) + >>> bias = torch.randn(8, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + packed_params = torch.ops.quantized.conv2d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point) + + +def conv3d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 3D convolution over a quantized 3D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv3d` for details and output shape. + + Args: + input: quantized input tensor of shape + :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)` + weight: quantized filters of shape + :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)` + bias: **non-quantized** bias tensor of shape + :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sD, sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padD, padH, padW)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dD, dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be + divisible by the number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for + quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float) + >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float) + >>> bias = torch.randn(8, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + + packed_params = torch.ops.quantized.conv3d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point) + + +def interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + r"""Down/up samples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + See :func:`torch.nn.functional.interpolate` for implementation details. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D/3D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.interpolate' must be quantized!") + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + +def linear( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + scale: float | None = None, + zero_point: int | None = None, +) -> Tensor: + r""" + Applies a linear transformation to the incoming quantized data: + :math:`y = xA^T + b`. + See :class:`~torch.ao.nn.quantized.Linear` + + .. note:: + + Current implementation packs weights on every call, which has penalty on performance. + If you want to avoid the overhead, use :class:`~torch.ao.nn.quantized.Linear`. + + Args: + input (Tensor): Quantized input of type `torch.quint8` + weight (Tensor): Quantized weight of type `torch.qint8` + bias (Tensor): None or fp32 bias of type `torch.float` + scale (double): output scale. If None, derived from the input scale + zero_point (long): output zero point. If None, derived from the input zero_point + + Shape: + - Input: :math:`(N, *, in\_features)` where `*` means any number of + additional dimensions + - Weight: :math:`(out\_features, in\_features)` + - Bias: :math:`(out\_features)` + - Output: :math:`(N, *, out\_features)` + """ + if scale is None: + scale = input.q_scale() + if zero_point is None: + zero_point = input.q_zero_point() + _packed_params = torch.ops.quantized.linear_prepack(weight, bias) + return torch.ops.quantized.linear(input, _packed_params, scale, zero_point) + + +def max_pool1d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + r"""Applies a 1D max pooling over a quantized input signal composed of + several quantized input planes. + + .. note:: The input quantization parameters are propagated to the output. + + See :class:`~torch.ao.nn.quantized.MaxPool1d` for details. + """ + if return_indices: + raise NotImplementedError("return_indices is not yet implemented!") + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.nn.functional.max_pool1d( + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + + +def max_pool2d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + r"""Applies a 2D max pooling over a quantized input signal composed of + several quantized input planes. + + .. note:: The input quantization parameters are propagated to the output. + + See :class:`~torch.ao.nn.quantized.MaxPool2d` for details. + """ + if return_indices: + raise NotImplementedError("return_indices is not yet implemented!") + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.nn.functional.max_pool2d( + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + + +def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor: + r"""celu(input, scale, zero_point, alpha=1.) -> Tensor + + Applies the quantized CELU function element-wise. + + .. math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1)) + + Args: + input: quantized input + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.celu' must be quantized!") + return torch.ops.quantized.celu(input, scale, zero_point, alpha) + + +def leaky_relu( + input: Tensor, + negative_slope: float = 0.01, + inplace: bool = False, + scale: float | None = None, + zero_point: int | None = None, +): + r""" + Quantized version of the. + leaky_relu(input, negative_slope=0.01, inplace=False, scale, zero_point) -> Tensor + + Applies element-wise, + :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` + + Args: + input: Quantized input + negative_slope: The slope of the negative input + inplace: Inplace modification of the input tensor + scale, zero_point: Scale and zero point of the output tensor. + + See :class:`~torch.nn.LeakyReLU` for more details. + """ + if scale is not None and zero_point is not None: + assert not inplace, "Cannot rescale with `inplace`" + output = torch._empty_affine_quantized( + input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype + ) + torch._C._nn.leaky_relu(input, negative_slope, out=output) + return output + if inplace: + result = torch._C._nn.leaky_relu_(input, negative_slope) + else: + result = torch._C._nn.leaky_relu(input, negative_slope) + return result + + +def hardtanh( + input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False +) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`.""" + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardtanh' must be quantized!") + if inplace: + return torch._C._nn.hardtanh_(input, min_val, max_val) + return torch._C._nn.hardtanh(input, min_val, max_val) + + +def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`. + + Args: + input: quantized input + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardswish' must be quantized!") + return torch._ops.ops.quantized.hardswish(input, scale, zero_point) + + +def threshold(input: Tensor, threshold: float, value: float) -> Tensor: + r"""Applies the quantized version of the threshold function element-wise: + + .. math:: + x = \begin{cases} + x & \text{if~} x > \text{threshold} \\ + \text{value} & \text{otherwise} + \end{cases} + + See :class:`~torch.nn.Threshold` for more details. + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.threshold' must be quantized!") + if threshold is None: + raise ValueError("Input to 'threshold' must be specified!") + if value is None: + raise ValueError("Input to 'value' must be specified!") + return torch._ops.ops.quantized.threshold(input, threshold, value) + + +def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.elu`. + + Args: + input: quantized input + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + alpha: the alpha constant + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.elu' must be quantized!") + return torch.ops.quantized.elu(input, scale, zero_point, alpha) + + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`.""" + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!") + if inplace: + return torch._C._nn.hardsigmoid_(input) # type: ignore[attr-defined] + return torch._C._nn.hardsigmoid(input) + + +def clamp(input: Tensor, min_: float, max_: float) -> Tensor: + r"""float(input, min\_, max\_) -> Tensor + + Applies the clamp function element-wise. + See :class:`~torch.ao.nn.quantized.clamp` for more details. + + Args: + input: quantized input + min_: minimum value for clamping + max_: maximum value for clamping + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.clamp' must be quantized!") + return torch.clamp(input, min_, max_) + + +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + r"""Upsamples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(...)``. + + See :func:`torch.nn.functional.interpolate` for implementation details. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): quantized input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`bilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. + See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + """ + warnings.warn( + "nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode, align_corners) + + +def upsample_bilinear(input, size=None, scale_factor=None): + r"""Upsamples the input, using bilinear upsampling. + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with + ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int]): output spatial size. + scale_factor (int or Tuple[int, int]): multiplier for spatial size + """ + # DeprecationWarning is ignored by default + warnings.warn( + "nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + +def upsample_nearest(input, size=None, scale_factor=None): + r"""Upsamples the input, using nearest neighbours' pixel values. + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(..., mode='nearest')``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial + size. + scale_factor (int): multiplier for spatial size. Has to be an integer. + """ + # DeprecationWarning is ignored by default + warnings.warn( + "nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead.", + stacklevel=2, + ) + return interpolate(input, size, scale_factor, mode="nearest") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bad8c49350f56e5e58235570799a8d0968296d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__init__.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +import torch + +# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable` +# packages. However, the `quantizable` package uses "lazy imports" +# to avoid circular dependency. +# Hence we need to include it here to make sure it is resolved before +# they are used in the modules. +import torch.ao.nn.quantizable +from torch.nn.modules.pooling import MaxPool2d + +from .activation import ( + ELU, + Hardswish, + LeakyReLU, + MultiheadAttention, + PReLU, + ReLU6, + Sigmoid, + Softmax, +) +from .batchnorm import BatchNorm2d, BatchNorm3d +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .dropout import Dropout +from .embedding_ops import Embedding, EmbeddingBag +from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional +from .linear import Linear +from .normalization import ( + GroupNorm, + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LayerNorm, +) +from .rnn import LSTM + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] + + +class Quantize(torch.nn.Module): + r"""Quantizes an incoming tensor + + Args: + `scale`: scale of the output Quantized Tensor + `zero_point`: zero_point of output Quantized Tensor + `dtype`: data type of output Quantized Tensor + `factory_kwargs`: Dictionary of kwargs used for configuring initialization + of internal buffers. Currently, `device` and `dtype` are supported. + Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}` + will initialize internal buffers as type `torch.float64` on the current CUDA device. + Note that `dtype` only applies to floating-point buffers. + + Examples:: + >>> t = torch.tensor([[1., -1.], [1., -1.]]) + >>> scale, zero_point, dtype = 1.0, 2, torch.qint8 + >>> qm = Quantize(scale, zero_point, dtype) + >>> # xdoctest: +SKIP + >>> qt = qm(t) + >>> print(qt) + tensor([[ 1., -1.], + [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2) + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__(self, scale, zero_point, dtype, factory_kwargs=None): + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + super().__init__() + self.register_buffer("scale", torch.tensor([scale], **factory_kwargs)) + self.register_buffer( + "zero_point", + torch.tensor( + [zero_point], + dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ), + ) + self.dtype = dtype + + def forward(self, X): + return torch.quantize_per_tensor( + X, float(self.scale), int(self.zero_point), self.dtype + ) + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + assert hasattr(mod, "activation_post_process") + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Quantize( + scale.float().item(), + zero_point.long().item(), + mod.activation_post_process.dtype, + ) + + def extra_repr(self): + return f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}" + + +class DeQuantize(torch.nn.Module): + r"""Dequantizes an incoming tensor + + Examples:: + >>> input = torch.tensor([[1., -1.], [1., -1.]]) + >>> scale, zero_point, dtype = 1.0, 2, torch.qint8 + >>> qm = Quantize(scale, zero_point, dtype) + >>> # xdoctest: +SKIP + >>> quantized_input = qm(input) + >>> dqm = DeQuantize() + >>> dequantized = dqm(quantized_input) + >>> print(dequantized) + tensor([[ 1., -1.], + [ 1., -1.]], dtype=torch.float32) + """ + + def forward(self, Xq): + return Xq.dequantize() + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + return DeQuantize() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c20fae3fc3c9857993166fd54d7a9c304a1c69d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d70bbd085cd80fcf64db0b0e605d8e455102e13c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32d45639957512e5487bb75a8f88ce2b57dd4bb0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3b69d018535cb44e01ca9f4802c6057e90d4160 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3bcc1309ac1c40d7c61e7213b59b1f925ed9b3f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe864e3435bfef9674369e13d518f5a5aaef4523 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a4f7d92c79c48fc67f9d2ef2eabcea52cbc142 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30ee713f00616de2127d83be621ea15f1cbf5ca0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fadf3c27b7aa77b48da2cccb76a256983db26f6e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c9b9ab5e34780cf8ac70db1fc109c9cb5e6f49 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0c188409f0e41042212b515ae96957422ad667f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/activation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..3ecf1d5c9a1e2c198d89f284e109dd9410994b60 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/activation.py @@ -0,0 +1,351 @@ +# mypy: allow-untyped-defs +from warnings import warn + +import torch + + +__all__ = [ + "ReLU6", + "Hardswish", + "ELU", + "LeakyReLU", + "Sigmoid", + "Softmax", + "MultiheadAttention", + "PReLU", +] + + +class ReLU6(torch.nn.ReLU): + r"""Applies the element-wise function: + + :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the + zero_point, and :math:`q(6)` is the quantized representation of number 6. + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/ReLU6.png + + Examples:: + + >>> m = nn.quantized.ReLU6() + >>> input = torch.randn(2) + >>> # xdoctest: +SKIP + >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32) + >>> output = m(input) + """ + + def __init__(self, inplace=False): + super().__init__(inplace) + self.inplace = inplace + + def forward(self, input): + return torch.ops.quantized.relu6(input, self.inplace) + + def _get_name(self): + return "QuantizedReLU6" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + return ReLU6(mod.inplace) + + +class Hardswish(torch.nn.Hardswish): + r"""This is the quantized version of :class:`~torch.nn.Hardswish`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, scale, zero_point, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.hardswish(input, self.scale, self.zero_point) + + def _get_name(self): + return "QuantizedHardswish" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Hardswish(float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point)) + + +class ELU(torch.nn.ELU): + r"""This is the quantized equivalent of :class:`~torch.nn.ELU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + alpha: the alpha constant + """ + + def __init__(self, scale, zero_point, alpha=1.0): + super().__init__(alpha) + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + return torch.ao.nn.quantized.functional.elu( + input, self.scale, self.zero_point, self.alpha + ) + + def _get_name(self): + return "QuantizedELU" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return ELU(float(scale), int(zero_point), mod.alpha) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.alpha) + + +class LeakyReLU(torch.nn.LeakyReLU): + r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + negative_slope: Controls the angle of the negative slope. Default: 1e-2 + """ + + def __init__( + self, + scale: float, + zero_point: int, + negative_slope: float = 1e-2, + inplace: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(negative_slope, inplace) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.leaky_relu( + input, self.negative_slope, self.inplace, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLeakyReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + + +class Sigmoid(torch.nn.Sigmoid): + r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, output_scale: float, output_zero_point: int): + super().__init__() + self.output_scale = output_scale + self.output_zero_point = output_zero_point + + def forward(self, input): + return torch.ops.quantized.sigmoid( + input, self.output_scale, self.output_zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + ( + output_scale, + output_zero_point, + ) = mod.activation_post_process.calculate_qparams() + return cls(float(output_scale), int(output_zero_point)) + + +class Softmax(torch.nn.Softmax): + r"""This is the quantized version of :class:`~torch.nn.Softmax`. + + Args: + dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1). + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, dim=None, scale=1.0, zero_point=0): + super().__init__() + self.dim = dim + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + dim = self.dim + if dim is None: + stacklevel = 3 + # Note: adding the mypy ignore on _get_softmax_dim seems less bad + # than making `_get_softmax_dim` an official API. + dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined] + "softmax", input.dim(), stacklevel + ) + return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point) + + def _get_name(self): + return "QuantizedSoftmax" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Softmax(mod.dim, float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.dim, float(scale), int(zero_point)) + + +class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention): + # pyrefly: ignore [bad-override] + _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention + + def _get_name(self): + return "QuantizedMultiheadAttention" + + @classmethod + def from_float(cls, other): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-observed MHA module. Please, see " + "the examples on quantizable MHAs." + ) + + @classmethod + def from_observed(cls, other): + converted = torch.ao.quantization.convert( + other, + mapping=None, + inplace=False, + remove_qconfig=True, + convert_custom_config_dict=None, + ) + converted.__class__ = cls + # Remove the parameters for the bias_k and bias_v to quantize them + # TODO: This is a potential source of accuracy drop. + # quantized cat takes the scale and zp of the first + # element, which might lose the precision in the bias_k + # and the bias_v (which are cat'ed with k/v being first). + if converted.bias_k is not None: + bias_k = converted._parameters.pop("bias_k") + sc, zp = torch._choose_qparams_per_tensor(bias_k, reduce_range=False) + bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8) + setattr(converted, "bias_k", bias_k) # noqa: B010 + + if converted.bias_v is not None: + bias_v = converted._parameters.pop("bias_v") + sc, zp = torch._choose_qparams_per_tensor( + bias_k, # type: ignore[possibly-undefined] + reduce_range=False, + ) + bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8) + setattr(converted, "bias_v", bias_v) # noqa: B010 + + del converted.in_proj_weight + del converted.in_proj_bias + + return converted + + +class PReLU(torch.nn.Module): + r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + num_parameters: number of parameters: 1, or the number of channels at input. Default: 1 + """ + + def __init__( + self, output_scale: float, output_zero_point: int, num_parameters: int = 1 + ) -> None: + super().__init__() + self.num_parameters = num_parameters + self.scale = output_scale + self.zero_point = output_zero_point + w = torch.randn(num_parameters, dtype=torch.float) + qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8) + self.set_weight(qw) + + def set_weight(self, w: torch.Tensor) -> None: + self.weight = w + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.prelu( + input, self.weight, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedPReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}", + stacklevel=2, + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8 + ) + qprelu.set_weight(qweight) + return qprelu + + @classmethod + def from_reference(cls, mod, scale, zero_point): + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}", + stacklevel=2, + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8 + ) + qprelu.set_weight(qweight) + return qprelu diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/batchnorm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e6779c08b1f6af61c2377335b984c7f75a29a6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/batchnorm.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni + + +__all__ = ["BatchNorm2d", "BatchNorm3d"] + + +class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm): + def __init__( + self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, True, True, **factory_kwargs) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs)) + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + activation_post_process = mod.activation_post_process + if type(mod) is cls._NNI_BN_RELU_MODULE: + mod = mod[0] + scale, zero_point = activation_post_process.calculate_qparams() + new_mod = cls(mod.num_features, mod.eps) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + new_mod.running_mean = mod.running_mean + new_mod.running_var = mod.running_var + new_mod.scale = scale + new_mod.zero_point = zero_point + return new_mod + + @classmethod + def from_reference(cls, bn, output_scale, output_zero_point): + qbn = cls( + bn.num_features, + bn.eps, + bn.momentum, + device=bn.weight.device, + dtype=bn.weight.dtype, + ) + qbn.weight = bn.weight + qbn.bias = bn.bias + qbn.running_mean = bn.running_mean + qbn.running_var = bn.running_var + qbn.scale = output_scale + qbn.zero_point = output_zero_point + return qbn + + +class BatchNorm2d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.""" + + _NNI_BN_RELU_MODULE = nni.BNReLU2d + + def __init__( + self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return "QuantizedBatchNorm2d" + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm2d( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return _BatchNorm.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class BatchNorm3d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.""" + + _NNI_BN_RELU_MODULE = nni.BNReLU3d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return "QuantizedBatchNorm3d" + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm3d( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return _BatchNorm.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a292d616a86c31d22550faa7d38d256350e4e91a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/conv.py @@ -0,0 +1,1244 @@ +# mypy: allow-untyped-defs +r"""Quantized convolution modules.""" + +from typing import ClassVar, Literal, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.nn as nn +import torch.nn.functional as F +from torch._ops import ops +from torch.nn.common_types import _size_1_t +from torch.nn.modules.utils import _pair, _single, _triple +from torch.nn.utils import fuse_conv_bn_weights + +from .utils import _quantize_weight, WeightedQuantizedModule + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + +_SUPPORTED_PADDING = {"zeros", "reflect"} + + +def _reverse_repeat_padding(padding: list[int]) -> list[int]: + _reversed_padding_repeated_twice: list[int] = [] + N = len(padding) + for idx in range(N): + _reversed_padding_repeated_twice.extend(padding[N - idx - 1] for _ in range(2)) + return _reversed_padding_repeated_twice + + +class _ConvNd(WeightedQuantizedModule): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + # All subclasses have this signature - See PR #49702s + raise NotImplementedError + + def _init( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode="zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + if padding_mode not in _SUPPORTED_PADDING: + raise ValueError( + f"'padding_mode' {padding_mode} is not supported by quantized convolution" + ) + self.padding_mode = padding_mode + # Initialize as NCHW. set_weight will internally transpose to NHWC. + if self.transposed: + weight_shape = [in_channels, out_channels // self.groups] + else: + weight_shape = [out_channels, in_channels // self.groups] + qweight = torch._empty_affine_quantized( + weight_shape + list(kernel_size), + scale=1, + zero_point=0, + dtype=torch.qint8, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + bias_float = ( + torch.zeros( + out_channels, + dtype=torch.float, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + if bias + else None + ) + + self.set_weight_bias(qweight, bias_float) + self.scale = 1.0 + self.zero_point = 0 + + def set_weight_bias(self, qweight, bias_float): + raise NotImplementedError + + def bias(self): + raise NotImplementedError + + def _weight_bias(self): + raise NotImplementedError + + def extra_repr(self): + s = ( + "{in_channels}, {out_channels}, kernel_size={kernel_size}" + ", stride={stride}, scale={scale}, zero_point={zero_point}" + ) + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias() is None: + s += ", bias=False" + return s.format(**self.__dict__) + + # ===== Serialization methods ===== + # The special consideration here is that we have to unpack the weights into + # their regular QTensor form for serialization. Packed weights should not + # live outside the process in which they were created, rather they should be + # derived from the QTensor weight. + # self + # |--- weight : Tensor + # |--- bias : Tensor + # + # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed + # self + # |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + (w, b) = self._weight_bias() + destination[prefix + "weight"] = w + destination[prefix + "bias"] = b + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + @torch.jit.export + def __getstate__(self): + (w, b) = self._weight_bias() + return ( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.transposed, + self.output_padding, + self.groups, + self.padding_mode, + w, + b, + self.scale, + self.zero_point, + self.training, + ) + + # ===== Deserialization methods ===== + # Counterpart to the serialization methods, we must pack the serialized + # QTensor weight into its packed format for use by the FBGEMM ops. + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.set_weight_bias(state_dict[prefix + "weight"], state_dict[prefix + "bias"]) + state_dict.pop(prefix + "weight") + state_dict.pop(prefix + "bias") + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def __setstate__(self, state): + self.in_channels = state[0] + self.out_channels = state[1] + self.kernel_size = state[2] + self.stride = state[3] + self.padding = state[4] + self.dilation = state[5] + self.transposed = state[6] + self.output_padding = state[7] + self.groups = state[8] + self.padding_mode = state[9] + self.set_weight_bias(state[10], state[11]) + self.scale = state[12] + self.zero_point = state[13] + self.training = state[14] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + torch.nn.Module.__init__(new_instance) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def __copy__(self): + return self.__deepcopy__({}) + + @classmethod + def get_qconv(cls, mod, activation_post_process, weight_post_process=None): + r"""Creates a qconv object and returns it.""" + if weight_post_process is None: + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvNd + qconv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, + mod.stride, + mod.padding, + mod.dilation, + mod.groups, + mod.bias is not None, + mod.padding_mode, + ) + qconv.set_weight_bias(qweight, mod.bias) + if ( + activation_post_process is None + or activation_post_process.dtype == torch.float + ): + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = activation_post_process.calculate_qparams() + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + if hasattr(mod, "weight_fake_quant"): + # assert type(mod) is cls.__QAT_MODULE, " nnq." + cls.__name__ + \ + # ".from_float only works for " + cls.__QAT_MODULE.__name__ + if type(mod) is cls._NNIQAT_CONV_BN_MODULE: + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + assert hasattr(mod, "activation_post_process"), ( + "Input QAT module must have observer attached" + ) + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + assert type(mod) is cls._FLOAT_MODULE, ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + + " but got:" + + str(type(mod)) + ) + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined." + ) + activation_post_process = ( + None + if not hasattr(mod, "activation_post_process") + else mod.activation_post_process + ) + if type(mod) in [ + cls._NNI_CONV_RELU_MODULE, + cls._NNI_CONV_ADD_MODULE, + cls._NNI_CONV_ADD_RELU_MODULE, + ]: + mod = mod[0] + weight_post_process = mod.qconfig.weight() + return cls.get_qconv(mod, activation_post_process, weight_post_process) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconv.in_channels, + ref_qconv.out_channels, + ref_qconv.kernel_size, # type: ignore[arg-type] + ref_qconv.stride, # type: ignore[arg-type] + ref_qconv.padding, # type: ignore[arg-type] + ref_qconv.dilation, # type: ignore[arg-type] + ref_qconv.groups, + ref_qconv.bias is not None, # type: ignore[arg-type] + ref_qconv.padding_mode, + device=ref_qconv.weight.device, + dtype=ref_qconv.weight.dtype, + ) + qweight = ref_qconv.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconv.bias) + qconv.scale = float(output_scale) + qconv.zero_point = int(output_zero_point) + return qconv + + +class Conv1d(_ConvNd): + r"""Applies a 1D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv1d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv1d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, + ... dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + # pyrefly: ignore [bad-assignment] + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + + # Subclasses of _ConvNd needs to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _single(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv1d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv1d_unpack(self._packed_params) + return w, b + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv1d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv2d(_ConvNd): + r"""Applies a 2D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv2d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE: ClassVar[type[nni.ConvAdd2d]] = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE: ClassVar[type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv2d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups + ) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv2d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv3d(_ConvNd): + r"""Applies a 3D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv3d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv3d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) + >>> input = torch.randn(20, 16, 56, 56, 56) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv3d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, _triple(0), self.dilation, self.groups + ) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv3d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +# === Transposed Convolutions === + + +class _ConvTransposeNd(_ConvNd): + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + ): + if padding_mode != "zeros": + raise ValueError( + f'Only "zeros" padding mode is supported for {self.__class__.__name__}' + ) + factory_kwargs = {"device": device, "dtype": dtype} + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _input_padding( + self, kernel_size: list[int], dilation: list[int], padding: list[int] + ) -> list[int]: + res = torch.jit.annotate(list[int], []) + for kdx in range(len(kernel_size)): + pad = dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx] + res.append(pad) + return res + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + # derived classes override cls._FLOAT_MODULE attribute + msg = ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined] + ) + assert type(mod) is cls._FLOAT_MODULE, msg + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined." + weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr] + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd + qconv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, # type: ignore[call-arg] + mod.stride, + mod.padding, + mod.output_padding, + mod.groups, + mod.bias is not None, + mod.dilation, + mod.padding_mode, + ) + qconv.set_weight_bias(qweight, mod.bias) + if ( + not hasattr(mod, "activation_post_process") + or mod.activation_post_process.dtype == torch.float + ): + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = mod.activation_post_process.calculate_qparams() # type: ignore[operator, union-attr] + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconvt.in_channels, + ref_qconvt.out_channels, + ref_qconvt.kernel_size, # type: ignore[arg-type] + ref_qconvt.stride, # type: ignore[arg-type] + ref_qconvt.padding, # type: ignore[arg-type] + ref_qconvt.output_padding, # type: ignore[arg-type] + ref_qconvt.groups, + ref_qconvt.bias is not None, # type: ignore[arg-type] + ref_qconvt.dilation, # type: ignore[arg-type] + ref_qconvt.padding_mode, + device=ref_qconvt.weight.device, + dtype=ref_qconvt.weight.dtype, + ) + qweight = ref_qconvt.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconvt.bias) + qconv.scale = float(output_scale) + qconv.zero_point = int(output_zero_point) + return qconv + + +class ConvTranspose1d(_ConvTransposeNd): + r"""Applies a 1D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose1d`. + + .. note:: Currently only the QNNPACK engine is implemented. + Please, set the `torch.backends.quantized.engine = 'qnnpack'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'qnnpack' + >>> from torch.ao.nn import quantized as nnq + >>> # With square kernels and equal stride + >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose1d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + return torch.ops.quantized.conv_transpose1d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) + + +class ConvTranspose2d(_ConvTransposeNd): + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose2d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # QNNPACK or FBGEMM as backend + >>> torch.backends.quantized.engine = 'qnnpack' + >>> # With square kernels and equal stride + >>> import torch.ao.nn.quantized as nnq + >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose2d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv2d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return ops.quantized.conv_transpose2d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) + + +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + .. note:: Currently only the FBGEMM engine is implemented. + Please, set the `torch.backends.quantized.engine = 'fbgemm'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'fbgemm' + >>> from torch.ao.nn import quantized as nnq + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> input = torch.randn(20, 16, 50, 100, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose3d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv3d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/dropout.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..3744ca30d5a49ba92cbb86690f2683af02d594fe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/dropout.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = ["Dropout"] + + +class Dropout(torch.nn.Dropout): + r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`. + And this is a placeholder to enable models where fp32 tensors + had dropout to work with quantized tensors in train and eval mode. + + Args: + p: probability of an element to be zeroed + inplace: can optionally do the operation in-place. Default: ``False`` + """ + + def forward(self, input): + return input + + def _get_name(self): + return "QuantizedDropout" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return cls(mod.p, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.p, mod.inplace) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7e843653ed27a49fa62d0f7e3408a7ac04f48fdf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py @@ -0,0 +1,413 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +from torch import Tensor # noqa: F401 +from torch._jit_internal import List, Optional # noqa: F401 + +from .utils import _hide_packed_params_repr, _quantize_weight + + +__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"] + + +class EmbeddingPackedParams(torch.nn.Module): + _version = 1 + + def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8): + super().__init__() + self.dtype = dtype + if self.dtype in [torch.quint8, torch.quint4x2]: + scales = torch.ones(num_embeddings, dtype=torch.float) + zero_points = torch.zeros(num_embeddings, dtype=torch.float) + wq = torch._empty_per_channel_affine_quantized( + [num_embeddings, embedding_dim], + scales=scales, + zero_points=zero_points, + axis=0, + dtype=self.dtype, + ) + self.set_weight(wq) + else: + raise NotImplementedError( + f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}" + ) + + @torch.jit.export + def set_weight(self, weight: torch.Tensor) -> None: + if self.dtype in [torch.quint8, torch.quint4x2]: + self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight) + else: + raise NotImplementedError( + "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2." + ) + + @torch.jit.export + def _weight(self): + if self.dtype in [torch.quint8, torch.quint4x2]: + return torch.ops.quantized.embedding_bag_unpack(self._packed_weight) + else: + raise NotImplementedError( + "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2." + ) + + def forward(self, x): + return x + + # Version 1 + # self + # |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase + # |--- dtype : torch.dtype + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_weight"] = self._weight() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.dtype = state_dict[prefix + "dtype"] + state_dict.pop(prefix + "dtype") + + weight = state_dict[prefix + "_packed_weight"] + state_dict.pop(prefix + "_packed_weight") + self.set_weight(weight) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __repr__(self): + return self._weight().__repr__() + + +class Embedding(torch.nn.Module): + r""" + A quantized Embedding module with quantized packed weights as inputs. + We adopt the same interface as `torch.nn.Embedding`, please see + https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation. + + Similar to :class:`~torch.nn.Embedding`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`. + + Examples:: + >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12) + >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8]) + >>> output = m(indices) + >>> print(output.size()) + torch.Size([9, 12]) + + """ + + _version = 1 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + dtype=torch.quint8, + ) -> None: + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.dtype = dtype + + if _weight is None: + scales = torch.ones(num_embeddings, dtype=torch.float) + zero_points = torch.zeros(num_embeddings, dtype=torch.float) + qweight = torch._empty_per_channel_affine_quantized( + [num_embeddings, embedding_dim], + scales=scales, + zero_points=zero_points, + axis=0, + dtype=torch.quint8, + ) + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + qweight = _weight + + self._packed_params = EmbeddingPackedParams( + num_embeddings, embedding_dim, dtype + ) + self._packed_params.set_weight(qweight) + + def forward(self, indices: Tensor) -> Tensor: + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_4bit( + self._packed_params._packed_weight, indices + ) + else: + return torch.ops.quantized.embedding_byte( + self._packed_params._packed_weight, indices + ) + + def _get_name(self): + return "QuantizedEmbedding" + + def __repr__(self): + return _hide_packed_params_repr(self, EmbeddingPackedParams) + + def extra_repr(self): + extra_repr_str = ( + f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, " + f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}" + ) + + return extra_repr_str + + def set_weight(self, w: torch.Tensor) -> None: + self._packed_params.set_weight(w) + + def weight(self): + return self._packed_params._weight() + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized embedding module from a float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by user + """ + if hasattr(mod, "weight_fake_quant"): + assert type(mod) is torch.ao.nn.qat.Embedding, ( + "nnq." + + cls.__name__ + + ".from_float " + + "with fake quant only works for " + + torch.ao.nn.qat.Embedding.__name__ + ) + weight_observer = mod.weight_fake_quant + else: + assert type(mod) is nn.Embedding, ( + "nnq." + + cls.__name__ + + ".from_float only works for " + + nn.Embedding.__name__ + ) + assert hasattr(mod, "qconfig"), ( + "Embedding input float module must have qconfig defined" + ) + from torch.ao.quantization import float_qparams_weight_only_qconfig + + if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] + weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator] + else: + weight_observer = float_qparams_weight_only_qconfig.weight() + + dtype = weight_observer.dtype + is_float_qparams_qconfig = ( + weight_observer.qscheme == torch.per_channel_affine_float_qparams + ) + assert is_float_qparams_qconfig, ( + "Embedding quantization is only supported with float_qparams_weight_only_qconfig." + ) + + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}" + ) + + # Run the observer to calculate qparams. + weight_observer(mod.weight) + qweight = _quantize_weight(mod.weight.float(), weight_observer) + + # Create quantized Embedding module and pass in the quantized weight + qembedding = Embedding(mod.num_embeddings, mod.embedding_dim) + qembedding.set_weight(qweight) + return qembedding + + @classmethod + def from_reference(cls, ref_embedding): + qembedding = cls( + ref_embedding.num_embeddings, + ref_embedding.embedding_dim, + ref_embedding.padding_idx, + ref_embedding.max_norm, + ref_embedding.norm_type, + ref_embedding.scale_grad_by_freq, + ref_embedding.sparse, + ref_embedding.get_quantized_weight(), + ref_embedding.weight_dtype, + ) + return qembedding + + +class EmbeddingBag(Embedding): + r""" + A quantized EmbeddingBag module with quantized packed weights as inputs. + We adopt the same interface as `torch.nn.EmbeddingBag`, please see + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation. + + Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`. + + Examples:: + >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum') + >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32]) + >>> output = m(indices, offsets) + >>> print(output.size()) + torch.Size([5, 12]) + + """ + + _version = 1 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "sum", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + dtype=torch.quint8, + ) -> None: + super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype) + + self.mode = mode + self.pruned_weights = False + self.include_last_offset = include_last_offset + self.dtype = dtype + + def forward( + self, + indices: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + compressed_indices_mapping: Optional[Tensor] = None, + ) -> Tensor: + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_bag_4bit( + self._packed_params._packed_weight, + indices, + offsets, + False, + 0, + self.pruned_weights, + per_sample_weights, + compressed_indices_mapping, + self.include_last_offset, + ) + else: + return torch.ops.quantized.embedding_bag_byte( + self._packed_params._packed_weight, + indices, + offsets, + False, + 0, + self.pruned_weights, + per_sample_weights, + compressed_indices_mapping, + self.include_last_offset, + ) + + def _get_name(self): + return "QuantizedEmbeddingBag" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized embedding_bag module from a float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by user + """ + if hasattr(mod, "weight_fake_quant"): + weight_observer = mod.weight_fake_quant + else: + assert type(mod) is nn.EmbeddingBag, ( + "nnq." + + cls.__name__ + + ".from_float only works for " + + nn.EmbeddingBag.__name__ + ) + assert hasattr(mod, "qconfig"), ( + "EmbeddingBag input float module must have qconfig defined" + ) + from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig + + if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] + weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator] + else: + weight_observer = float_qparams_weight_only_qconfig.weight() + + dtype = weight_observer.dtype + is_float_qparams_qconfig = ( + weight_observer.qscheme == torch.per_channel_affine_float_qparams + ) + assert is_float_qparams_qconfig, ( + "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig." + ) + + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}" + ) + + # Run the observer to calculate qparams. + weight_observer(mod.weight) + qweight = _quantize_weight(mod.weight.float(), weight_observer) + + # Create quantized EmbeddingBag module and pass in the quantized weight + qembedding_bag = EmbeddingBag( + mod.num_embeddings, + mod.embedding_dim, + max_norm=mod.max_norm, + norm_type=mod.norm_type, + scale_grad_by_freq=mod.scale_grad_by_freq, + mode=mod.mode, + sparse=mod.sparse, + include_last_offset=mod.include_last_offset, + dtype=dtype, + ) + qembedding_bag.set_weight(qweight) + return qembedding_bag + + @classmethod + def from_reference(cls, ref_embedding_bag): + qembedding_bag = cls( + ref_embedding_bag.num_embeddings, + ref_embedding_bag.embedding_dim, + ref_embedding_bag.max_norm, + ref_embedding_bag.norm_type, + ref_embedding_bag.scale_grad_by_freq, + ref_embedding_bag.mode, + ref_embedding_bag.sparse, + ref_embedding_bag.get_quantized_weight(), + ref_embedding_bag.include_last_offset, + ref_embedding_bag.weight_dtype, + ) + return qembedding_bag diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/functional_modules.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/functional_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..acb578d0cc7989ecedd92fcb30664d50b4c18f87 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/functional_modules.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs + +import torch +from torch import Tensor +from torch._ops import ops + + +__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"] + + +class FloatFunctional(torch.nn.Module): + r"""State collector class for float operations. + + The instance of this class can be used instead of the ``torch.`` prefix for + some operations. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). + + Examples:: + + >>> f_add = FloatFunctional() + >>> a = torch.tensor(3.0) + >>> b = torch.tensor(4.0) + >>> f_add.add(a, b) # Equivalent to ``torch.add(a, b)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def __init__(self) -> None: + super().__init__() + self.activation_post_process = torch.nn.Identity() + + def forward(self, x): + raise RuntimeError( + "FloatFunctional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + r = self.activation_post_process(r) + return r + + +class FXFloatFunctional(torch.nn.Module): + r"""module to replace FloatFunctional module before FX graph mode quantization, + since activation_post_process will be inserted in top level module directly + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def forward(self, x): + raise RuntimeError( + "FloatFunctional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + return r + + +class QFunctional(torch.nn.Module): + r"""Wrapper class for quantized operations. + + The instance of this class can be used instead of the + ``torch.ops.quantized`` prefix. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). + + Examples:: + + >>> q_add = QFunctional() + >>> # xdoctest: +SKIP + >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32) + >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32) + >>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def __init__(self) -> None: + super().__init__() + self.scale = 1.0 + self.zero_point = 0 + self.activation_post_process = torch.nn.Identity() + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict.pop(prefix + "scale")) + self.zero_point = int(state_dict.pop(prefix + "zero_point")) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _get_name(self): + return "QFunctional" + + def extra_repr(self): + return f"scale={self.scale}, zero_point={self.zero_point}" + + def forward(self, x): + raise RuntimeError( + "Functional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.ops.quantized.add``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.add_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.mul_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.ops.quantized.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add_relu``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) is FloatFunctional, ( + "QFunctional.from_float expects an instance of FloatFunctional" + ) + scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] + new_mod = QFunctional() + new_mod.scale = float(scale) + new_mod.zero_point = int(zero_point) + return new_mod diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..84fa07b4a02207a34c16747d52d7283ad2ecfc8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/linear.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from collections.abc import Iterable + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.nn as nn +from torch.nn.utils.fusion import fuse_linear_bn_weights +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import _hide_packed_params_repr, _quantize_weight, WeightedQuantizedModule + + +__all__ = ["LinearPackedParams", "Linear"] + + +class LinearPackedParams(torch.nn.Module): + _version = 3 + + def __init__(self, dtype=torch.qint8): + super().__init__() + self.dtype = dtype + if self.dtype == torch.qint8: + wq = torch._empty_affine_quantized( + [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8 + ) + elif self.dtype == torch.float16: + wq = torch.zeros([1, 1], dtype=torch.float) + self.set_weight_bias(wq, None) # type: ignore[possibly-undefined] + + @torch.jit.export + def set_weight_bias(self, weight: torch.Tensor, bias: torch.Tensor | None) -> None: + if self.dtype == torch.qint8: + self._packed_params = torch.ops.quantized.linear_prepack(weight, bias) + elif self.dtype == torch.float16: + self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + + @torch.jit.export + def _weight_bias(self): + if self.dtype == torch.qint8: + return torch.ops.quantized.linear_unpack(self._packed_params) + elif self.dtype == torch.float16: + return torch.ops.quantized.linear_unpack_fp16(self._packed_params) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + + def forward(self, x): + return x + + # Version 1 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 2 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- dtype : torch.dtype + # + # Version 3 + # self + # |--- _packed_params : (Tensor, Tensor) representing (weight, bias) + # of LinearPackedParams + # |--- dtype : torch.dtype + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_params"] = self._weight_bias() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + self.dtype = torch.qint8 + else: + self.dtype = state_dict[prefix + "dtype"] + state_dict.pop(prefix + "dtype") + + if version is None or version < 3: + self.set_weight_bias( + state_dict[prefix + "weight"], state_dict[prefix + "bias"] + ) + state_dict.pop(prefix + "weight") + state_dict.pop(prefix + "bias") + + if version == 3: + weight, bias = state_dict[prefix + "_packed_params"] + state_dict.pop(prefix + "_packed_params") + self.set_weight_bias(weight, bias) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __repr__(self): + return self._weight_bias().__repr__() + + +class Linear(WeightedQuantizedModule): + r""" + A quantized linear module with quantized tensor as inputs and outputs. + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. + + Similar to :class:`~torch.nn.Linear`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{out\_features}, \text{in\_features})`. + bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized to zero. + scale: `scale` parameter of output Quantized Tensor, type: double + zero_point: `zero_point` parameter for output Quantized Tensor, type: long + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> m = nn.quantized.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> # xdoctest: +SKIP + >>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _version = 3 + _FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear) + + def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): + super().__init__() + # We don't muck around with buffers or attributes or anything here + # to keep the module simple. *everything* is simply a Python attribute. + # Serialization logic is explicitly handled in the below serialization and + # deserialization modules + self.in_features = in_features + self.out_features = out_features + bias = None + if bias_: + bias = torch.zeros(out_features, dtype=torch.float) + + if dtype == torch.qint8: + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + elif dtype == torch.float16: + qweight = torch.zeros([out_features, in_features], dtype=torch.float) + else: + raise RuntimeError("Unsupported dtype specified for quantized Linear!") + + self._packed_params = LinearPackedParams(dtype) + self._packed_params.set_weight_bias(qweight, bias) + self.scale = 1.0 + self.zero_point = 0 + + def _get_name(self): + return "QuantizedLinear" + + def extra_repr(self): + return ( + f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, " + f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}" + ) + + def __repr__(self): + return _hide_packed_params_repr(self, LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + # ===== Serialization methods ===== + # The special consideration here is that we have to unpack the weights into their + # regular QTensor form for serialization. Packed weights should not live + # outside the process in which they were created, rather they should be derived + # from the QTensor weight. + # + # Version 1 + # self + # |--- scale : float + # |--- zero_point : int + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 2 + # self + # |--- scale : float + # |--- zero_point : int + # |--- _packed_params : Module + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 3 + # self + # |--- scale : float + # |--- zero_point : int + # |--- _packed_params : Module + # |--- _packed_params : (Tensor, Tensor) representing weight, bias + # of LinearPackedParams C++ struct + # + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + # ===== Deserialization methods ===== + # Counterpart to the serialization methods, we must pack the serialized QTensor + # weight into its packed format for use by the FBGEMM ops. + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + + version = local_metadata.get("version", None) + + if version is None or version == 1: + # We moved the parameters into a LinearPackedParameters submodule + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias") + state_dict.update( + { + prefix + "_packed_params.weight": weight, + prefix + "_packed_params.bias": bias, + } + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + # Function rather than property to make sure that JIT serialization doesn't + # register this as an attribute + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias(self, w: torch.Tensor, b: torch.Tensor | None) -> None: + self._packed_params.set_weight_bias(w, b) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized module from an observed float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + use_precomputed_fake_quant (bool): if True, the module will reuse min/max + values from the precomputed fake quant module. + """ + if hasattr(mod, "weight_fake_quant"): + if type_before_parametrizations(mod) == nniqat.LinearBn1d: + mod.weight, mod.bias = fuse_linear_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + # This function does not participate in JIT, so it is OK to ignore + # the type mismatch in assignment. Also, mypy has an issue with + # iterables not being implemented, so we are ignoring those too. + if not isinstance(cls._FLOAT_MODULE, Iterable): + # pyrefly: ignore [bad-assignment] + cls._FLOAT_MODULE = [cls._FLOAT_MODULE] + supported_modules = ", ".join( + [float_mod.__name__ for float_mod in cls._FLOAT_MODULE] + ) + error_msg = f"nnq.{cls.__name__}.from_float only works for {supported_modules}, but got: {type(mod)}" + assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, ( + error_msg.format() + ) + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined" + ) + activation_post_process = mod.activation_post_process + if type_before_parametrizations(mod) == nni.LinearReLU: + mod = mod[0] + weight_post_process = ( + mod.qconfig.weight() + if not hasattr(mod, "weight_fake_quant") + else mod.weight_fake_quant + ) + + if not use_precomputed_fake_quant: + # Observer may not have been called yet + # Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear.set_weight_bias(qweight, mod.bias) + qlinear.scale = float(act_scale) + qlinear.zero_point = int(act_zp) + return qlinear + + @classmethod + def from_reference(cls, ref_qlinear, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + + Args: + ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features) + qweight = ref_qlinear.get_quantized_weight() + qlinear.set_weight_bias(qweight, ref_qlinear.bias) + + qlinear.scale = float(output_scale) + qlinear.zero_point = int(output_zero_point) + return qlinear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/normalization.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..fa335b4699db5519e2e53f27aa18958b5afced94 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/normalization.py @@ -0,0 +1,358 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = [ + "LayerNorm", + "GroupNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", +] + + +class LayerNorm(torch.nn.LayerNorm): + r"""This is the quantized version of :class:`~torch.nn.LayerNorm`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + normalized_shape, + weight, + bias, + scale, + zero_point, + eps=1e-5, + elementwise_affine=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + # pyrefly: ignore [bad-argument-type] + **factory_kwargs, + ) + self.weight = weight + self.bias = bias + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.layer_norm( + input, + self.normalized_shape, + weight=self.weight, + bias=self.bias, + eps=self.eps, + output_scale=self.scale, + output_zero_point=self.zero_point, + ) + + def _get_name(self): + return "QuantizedLayerNorm" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.normalized_shape, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.elementwise_affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.normalized_shape, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.elementwise_affine, + ) + + +class GroupNorm(torch.nn.GroupNorm): + r"""This is the quantized version of :class:`~torch.nn.GroupNorm`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + __constants__ = ["num_groups", "num_channels", "eps", "affine"] + + def __init__( + self, + num_groups, + num_channels, + weight, + bias, + scale, + zero_point, + eps=1e-5, + affine=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs) + self.weight = weight + self.bias = bias + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.group_norm( + input, + self.num_groups, + self.weight, + self.bias, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedGroupNorm" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_groups, + mod.num_channels, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + +class InstanceNorm1d(torch.nn.InstanceNorm1d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm1d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + + +class InstanceNorm2d(torch.nn.InstanceNorm2d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + + +class InstanceNorm3d(torch.nn.InstanceNorm3d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + # pyrefly: ignore [bad-argument-type] + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + # pyrefly: ignore [bad-argument-type] + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5040b8c97d050102779c742989dd4f52cd3bffa8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/rnn.py @@ -0,0 +1,59 @@ +from typing import Any + +import torch + + +__all__ = [ + "LSTM", +] + + +class LSTM(torch.ao.nn.quantizable.LSTM): + r"""A quantized long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples in :class:`~torch.ao.nn.quantizable.LSTM` + + Examples:: + >>> # xdoctest: +SKIP + >>> custom_module_config = { + ... 'float_to_observed_custom_module_class': { + ... nn.LSTM: nn.quantizable.LSTM, + ... }, + ... 'observed_to_quantized_custom_module_class': { + ... nn.quantizable.LSTM: nn.quantized.LSTM, + ... } + ... } + >>> tq.prepare(model, prepare_custom_module_class=custom_module_config) + >>> tq.convert(model, convert_custom_module_class=custom_module_config) + """ + + _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment] + + def _get_name(self) -> str: + return "QuantizedLSTM" + + @classmethod + def from_float(cls, *args: Any, **kwargs: Any) -> None: + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-observed LSTM module. Please, see " + "the examples on quantizable LSTMs." + ) + + @classmethod + def from_observed(cls: type["LSTM"], other: torch.ao.nn.quantizable.LSTM) -> "LSTM": + assert isinstance(other, cls._FLOAT_MODULE) # type: ignore[has-type] + converted = torch.ao.quantization.convert( + other, inplace=False, remove_qconfig=True + ) + converted.__class__ = cls + return converted diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..330070913a7521871f123a3e076264498a6ef612 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/utils.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +import abc +import collections +import itertools + +import torch +from torch.nn.modules.module import _addindent + + +__all__ = [ + "WeightedQuantizedModule", +] + + +class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta): + """Wrapper for quantized modules than can be lowered from reference modules.""" + + @classmethod + @abc.abstractmethod + def from_reference(cls, ref_module, output_scale, output_zero_point): + raise NotImplementedError + + +def _get_weight_observer(observer): + # FakeQuantize observer + if hasattr(observer, "activation_post_process"): + observer = observer.activation_post_process + # UniformQuantizationObserverBase observer + return observer + + +def _needs_weight_clamping(observer, dtype): + observer = _get_weight_observer(observer) + if dtype in [torch.qint8, torch.quint8, torch.qint32]: + info = torch.iinfo(dtype) + return observer.quant_min > info.min or observer.quant_max < info.max + return False + + +def _clamp_weights(qweight, observer, scale, zp): + if not _needs_weight_clamping(observer, qweight.dtype): + return qweight + + observer = _get_weight_observer(observer) + min_, max_ = observer.quant_min, observer.quant_max + + # Doing this because can't use torch.ops.quantized.clamp() with per_channel qscheme yet. + qw_int_max = torch.clone(qweight.int_repr()).fill_(max_) + qw_int_min = torch.clone(qweight.int_repr()).fill_(min_) + qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max) + + if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]: + qweight = torch._make_per_tensor_quantized_tensor( + qw_int, scale.item(), zp.item() + ) + elif observer.qscheme in [ + torch.per_channel_symmetric, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + qweight = torch._make_per_channel_quantized_tensor( + qw_int, scale, zp, axis=observer.ch_axis + ) + else: + raise ValueError("Unexpected qscheme " + observer.qscheme) + return qweight + + +def _quantize_weight(float_wt, observer): + wt_scale, wt_zp = observer.calculate_qparams() + if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]: + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.qint8 + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]: + wt_axis = observer.ch_axis + qweight = torch.quantize_per_channel( + float_wt, + wt_scale.to(torch.double), + wt_zp.to(torch.int64), + wt_axis, + torch.qint8, + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + elif observer.qscheme == torch.per_channel_affine_float_qparams: + qweight = torch.quantize_per_channel( + float_wt, + wt_scale.to(torch.float), + wt_zp.to(torch.float), + observer.ch_axis, + observer.dtype, + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + else: + raise ValueError("Unexpected qscheme " + observer.qscheme) + return qweight + + +def _ntuple_from_first(n): + """Converts the argument to a tuple of size n + with the first element repeated.""" + + def parse(x): + while isinstance(x, collections.abc.Sequence): + if len(x) == n: + break + x = x[0] + return tuple(itertools.repeat(x, n)) + + return parse + + +def _hide_packed_params_repr(self, params): + # We don't want to show `PackedParams` children, hence custom + # `__repr__`. This is the same as nn.Module.__repr__, except the check + # for the `params module`. + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + if isinstance(module, params): + continue + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + +_pair_from_first = _ntuple_from_first(2) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e15e9c1516d30f7ca9ee47b21b267533de75b6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__init__.py @@ -0,0 +1,19 @@ +from .modules import * # noqa: F403 + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "GRU", + "Embedding", + "EmbeddingBag", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1a836c7e66b1ecd9c2632ea9ae5f59dcbcc9570 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe97c22f5a46a5eafc1432075fc57dd44c3aa8d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py @@ -0,0 +1,29 @@ +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .linear import Linear +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell +from .sparse import Embedding, EmbeddingBag + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "GRU", + "Embedding", + "EmbeddingBag", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83511fb6762b2b330c548d864ad22900ff568bf6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bd4269e4f7cc1b286ae471e883568caeb759858 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba0cdb2e84c1e2d041c92138fd187c5b7ea38ed2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f6cfa9be4b8e5eba2c5a4218f7bd39d899b85c9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..141b381626f644b6bdc54d18c9220ae8bc01658c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfe7e93eaa23d9cd91d0933cefa56cd7276041b1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/conv.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3273b89cc70ab21a87a0369e71c3ceff19615111 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/conv.py @@ -0,0 +1,518 @@ +# mypy: allow-untyped-defs +from typing import Any, Literal, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.common_types import _size_1_t + +from .utils import ReferenceQuantizedModule + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule): + """A reference version of nn.quantized.Conv2d + we will not pack the parameters in this module, since weight packing is an + optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), + this is useful when user want to use this module in other backends like Glow. + """ + + __annotations__ = {"bias": Optional[torch.Tensor]} + _IS_REFERENCE = True + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + + +class Conv1d(_ConvNd, nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ): + nn.Conv1d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv1d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + weight_quant_dequant = self.get_weight() + + result = F.conv1d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class Conv2d(_ConvNd, nn.Conv2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ): + nn.Conv2d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + # pyrefly: ignore [bad-argument-type] + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv2d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + weight_quant_dequant = self.get_weight() + + result = F.conv2d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv2d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class Conv3d(_ConvNd, nn.Conv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ): + nn.Conv3d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + # pyrefly: ignore [bad-argument-type] + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv3d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + weight_quant_dequant = self.get_weight() + + result = F.conv3d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd): + """A reference version of nn.quantized.ConvTranspose2d + we will not pack the parameters in this module, since weight packing is an + optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), + this is useful when user want to use this module in other backends like Glow. + """ + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.output_padding, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + + +class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ): + nn.ConvTranspose1d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: list[int] | None = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose1d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose1d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + return result + + def _get_name(self): + return "QuantizedConvTranspose1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) + + +class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ): + nn.ConvTranspose2d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + # pyrefly: ignore [bad-argument-type] + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: list[int] | None = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose2d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose2d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + return result + + def _get_name(self): + return "QuantizedConvTranspose2d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) + + +class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ): + nn.ConvTranspose3d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + # pyrefly: ignore [bad-argument-type] + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: list[int] | None = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose3d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose3d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + return result + + def _get_name(self): + return "QuantizedConvTranspose3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..6014fab24036c30b183f5622d12aae4a345baedb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/linear.py @@ -0,0 +1,69 @@ +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import ReferenceQuantizedModule + + +__all__ = ["Linear"] + + +class Linear(nn.Linear, ReferenceQuantizedModule): + """A reference quantized linear module that fits into the FX + Graph Mode Quantization workflow + activation will be floating point Tensor, we will store floating + point weight as well in the module, but in forward we'll quantize + and dequantize the weight before running the floating point functional + linear operator. + """ + + _IS_REFERENCE = True + + def __init__( + self, + in_features: int, + out_features: int, + bias_: bool = True, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + weight_qparams: dict[str, Any] | None = None, + ) -> None: + super().__init__(in_features, out_features, bias_, device, dtype) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self) -> str: + return "QuantizedLinear(Reference)" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.linear --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.linear --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized linear + """ + weight_quant_dequant = self.get_weight() + result = F.linear(x, weight_quant_dequant, self.bias) + return result + + @classmethod + def from_float( + cls, float_linear: nn.Linear, weight_qparams: dict[str, Any] + ) -> "Linear": + qref_linear = Linear( + float_linear.in_features, + float_linear.out_features, + float_linear.bias is not None, + device=float_linear.weight.device, + dtype=float_linear.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach()) + if float_linear.bias is not None: + qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach()) + return qref_linear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdbfb81430b4db9e09ea752310732b89f47bfa1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py @@ -0,0 +1,861 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch +import torch.nn as nn +from torch import _VF, Tensor +from torch.nn.utils.rnn import PackedSequence + +from .utils import _quantize_and_dequantize_weight, _quantize_weight + + +__all__ = [ + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", + "RNNBase", + "LSTM", + "GRU", + "get_quantized_weight", +] + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +def _get_weight_and_quantization_params(module, wn): + weight = getattr(module, wn) + params = [weight] + for param_name in [ + wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"] + ]: + if hasattr(module, param_name): + param = getattr(module, param_name) + else: + param = None + params.append(param) + return params + + +def get_quantized_weight(module, wn): + if not hasattr(module, wn): + return None + params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_weight(*params) + return weight + + +def _get_quantize_and_dequantized_weight(module, wn): + if not hasattr(module, wn): + return None + params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_and_dequantize_weight(*params) + return weight + + +class RNNCellBase(nn.RNNCellBase): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + weight_qparams_dict=None, + ) -> None: + super().__init__( + input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype + ) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + weight_qparams_dict = { + "weight_ih": weight_qparams, + "weight_hh": weight_qparams, + "is_decomposed": False, + } + assert len(weight_qparams_dict) == 3, ( + "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)" + ) + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + assert weight_qparams_dict is not None + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + # TODO: refactor the duplicated code to utils.py + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + ], Exception( + f"qscheme: {weight_qscheme} is not support in {self._get_name()}" + ) + if weight_qscheme is not None: + scale = weight_qparams["scale"] + scale_tensor = ( + scale.detach().clone() + if isinstance(scale, torch.Tensor) + else torch.tensor(scale, dtype=torch.float, device=device) + ) + self.register_buffer(key + "_scale", scale_tensor) + zp = weight_qparams["zero_point"] + zp_tensor = ( + zp.detach().clone() + if isinstance(zp, torch.Tensor) + else torch.tensor(zp, dtype=torch.int, device=device) + ) + self.register_buffer(key + "_zero_point", zp_tensor) + if weight_qscheme == torch.per_channel_affine: + axis = weight_qparams["axis"] + axis_tensor = ( + axis.detach().clone() + if isinstance(axis, torch.Tensor) + else torch.tensor(axis, dtype=torch.int, device=device) + ) + self.register_buffer(key + "_axis", axis_tensor) + else: + # added for TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + + def _get_name(self): + return "QuantizedRNNCellBase(Reference)" + + def get_quantized_weight_ih(self): + return get_quantized_weight(self, "weight_ih") + + def get_quantized_weight_hh(self): + return get_quantized_weight(self, "weight_hh") + + def get_weight_ih(self): + return _get_quantize_and_dequantized_weight(self, "weight_ih") + + def get_weight_hh(self): + return _get_quantize_and_dequantized_weight(self, "weight_hh") + + +class RNNCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + weight_qparams_dict: dict[str, Any] | None = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def _get_name(self): + return "QuantizedRNNCell(Reference)" + + # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input + # and remove duplicated code, same for the other two Cell modules + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: + assert input.dim() in ( + 1, + 2, + ), ( + f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.nonlinearity, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class LSTMCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + weight_qparams_dict: dict[str, Any] | None = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def _get_name(self): + return "QuantizedLSTMCell(Reference)" + + def forward( + self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None + ) -> tuple[Tensor, Tensor]: + assert input.dim() in ( + 1, + 2, + ), ( + f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class GRUCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + weight_qparams_dict: dict[str, Any] | None = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def _get_name(self): + return "QuantizedGRUCell(Reference)" + + def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor: + assert input.dim() in ( + 1, + 2, + ), ( + f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class RNNBase(nn.RNNBase): + def __init__( + self, + mode: str, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + weight_qparams_dict: dict[str, Any] | None = None, + ) -> None: + super().__init__( + mode, + input_size, + hidden_size, + num_layers, + bias, + batch_first, + dropout, + bidirectional, + proj_size, + device, + dtype, + ) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + weight_qparams_dict = {"is_decomposed": False} # type: ignore[dict-item] + for wn in self._flat_weights_names: + if wn.startswith("weight"): + weight_qparams_dict[wn] = weight_qparams + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + ], Exception( + f"qscheme: {weight_qscheme} is not support in {self._get_name()}" + ) + if weight_qscheme is not None: + self.register_buffer( + key + "_scale", + torch.tensor( + weight_qparams["scale"], dtype=torch.float, device=device + ), + ) + self.register_buffer( + key + "_zero_point", + torch.tensor( + weight_qparams["zero_point"], dtype=torch.int, device=device + ), + ) + if weight_qscheme == torch.per_channel_affine: + self.register_buffer( + key + "_axis", + torch.tensor( + weight_qparams["axis"], dtype=torch.int, device=device + ), + ) + else: + # added for TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + + +class LSTM(RNNBase): + """Reference Quantized LSTM Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, + to the weight_qparams for that weight + """ + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Tensor | None, + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + def get_expected_cell_size( + self, input: Tensor, batch_sizes: Tensor | None + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args( # type: ignore[override] + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], + batch_sizes: Tensor | None, + ): + self.check_input(input, batch_sizes) + self.check_hidden_size( + hidden[0], + self.get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {}, got {}", + ) + self.check_hidden_size( + hidden[1], + self.get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {}, got {}", + ) + + def get_quantized_weight_bias_dict(self): + """dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... + } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = ( + self.proj_size if self.proj_size > 0 else self.hidden_size + ) + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + else: + if batch_sizes is None: # If not PackedSequence input. + if is_batched: # type: ignore[possibly-undefined] + if hx[0].dim() != 3 or hx[1].dim() != 3: + msg = ( + "For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ( + "For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.lstm( + input, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.lstm( + input, + batch_sizes, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, + # pyrefly: ignore [bad-argument-type] + batch_sizes, + sorted_indices, + unsorted_indices, + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedLSTM(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict, + ) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, getattr(mod, wn)) + return ref_mod + + +class GRU(RNNBase): + """Reference Quantized GRU Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, + to the weight_qparams for that weight + """ + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + super().__init__("GRU", *args, **kwargs) + + def get_quantized_weight_bias_dict(self): + """dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... + } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py + # only changed self._flat_weights to self.get_flat_weights() + # TODO: maybe we can try inheriting from that class and define get_flat_weights + # as a @property? this might interfere with TorchScript, if we remove that + # requirement in the future we should be able to do this + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + assert input.dim() in ( + 2, + 3, + ), ( + f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru( + input, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.gru( + input, + batch_sizes, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, + # pyrefly: ignore [bad-argument-type] + batch_sizes, + sorted_indices, + unsorted_indices, + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedGRU(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict, + ) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, getattr(mod, wn)) + return ref_mod diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff80997c1439c50a456df328b4068ae0c419a01 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py @@ -0,0 +1,163 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from .utils import ReferenceQuantizedModule + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(nn.Embedding, ReferenceQuantizedModule): + """A reference quantized Embedding module that fits into the + FX Graph Mode Quantization workflow, activation will be floating point Tensor, + we will store floating point weight as well in the module, but in forward we'll + quantize and dequantize the weight before running the floating point functional + embedding operator. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int | None = None, + max_norm: float | None = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Tensor | None = None, + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + # pyrefly: ignore [bad-argument-type] + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self): + return "QuantizedEmbedding(Reference)" + + def forward(self, input: Tensor) -> Tensor: + weight_quant_dequant = self.get_weight() + return F.embedding( + input, + weight_quant_dequant, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + @classmethod + def from_float(cls, mod, weight_qparams): + return cls( + mod.num_embeddings, + mod.embedding_dim, + mod.padding_idx, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.sparse, + mod.weight, + mod.weight.device, + mod.weight.dtype, + weight_qparams, + ) + + +class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule): + """A reference quantized EmbeddingBag module that fits into the + FX Graph Mode Quantization workflow, activation will be floating point Tensor, + we will store floating point weight as well in the module, but in forward we'll + quantize and dequantize the weight before running the floating point functional + embedding operator. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: float | None = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + _weight: Tensor | None = None, + include_last_offset: bool = False, + padding_idx: int | None = None, + device=None, + dtype=None, + weight_qparams: dict[str, Any] | None = None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + max_norm, + norm_type, + scale_grad_by_freq, + mode, + sparse, + _weight, + include_last_offset, + padding_idx, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self): + return "QuantizedEmbedding(Reference)" + + def forward( + self, + input: Tensor, + offsets: Tensor | None = None, + per_sample_weights: Tensor | None = None, + ) -> Tensor: + weight_quant_dequant = self.get_weight() + return F.embedding_bag( + input, + weight_quant_dequant, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + @classmethod + def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False): + return cls( + mod.num_embeddings, + mod.embedding_dim, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.mode, + mod.sparse, + mod.weight, + mod.include_last_offset, + mod.padding_idx, + mod.weight.device, + mod.weight.dtype, + weight_qparams, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7bdbcd4a6739e528e679c67b6a6614ea373801d3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/utils.py @@ -0,0 +1,438 @@ +# mypy: allow-untyped-defs +import typing + +import torch + + +__all__ = [ + "ReferenceQuantizedModule", +] + + +class ReferenceQuantizedModule(torch.nn.Module): + def _init_weight_qparams(self, weight_qparams, device): + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + # pyrefly: ignore [bad-assignment] + self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ], ( + f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}" + ) + if self.weight_dtype in [ + torch.quint8, + torch.qint8, + torch.quint4x2, + torch.qint32, + ]: + zero_point_dtype = ( + weight_qparams["zero_point"].dtype + if isinstance(weight_qparams["zero_point"], torch.Tensor) + else torch.int + ) + w_scale = weight_qparams["scale"] + w_scale_tensor = ( + w_scale.detach().clone() + if isinstance(w_scale, torch.Tensor) + else torch.tensor(w_scale, dtype=torch.float, device=device) + ) + self.register_buffer("weight_scale", w_scale_tensor) + w_zp = weight_qparams["zero_point"] + w_zp_tensor = ( + w_zp.detach().clone() + if isinstance(w_zp, torch.Tensor) + else torch.tensor(w_zp, dtype=zero_point_dtype, device=device) + ) + self.register_buffer("weight_zero_point", w_zp_tensor) + if self.weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + w_axis = weight_qparams["axis"] + w_axis_tensor = ( + w_axis.detach().clone() + if isinstance(w_axis, torch.Tensor) + else torch.tensor(w_axis, dtype=torch.int, device=device) + ) + self.register_buffer("weight_axis", w_axis_tensor) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + else: + # added for TorchScriptability, and for torch.float + self.register_buffer( + "weight_scale", torch.tensor(1.0, dtype=torch.float, device=device) + ) + self.register_buffer( + "weight_zero_point", torch.tensor(0, dtype=torch.int, device=device) + ) + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + # pyrefly: ignore [bad-assignment] + self.is_decomposed: bool = weight_qparams.get("is_decomposed", False) + # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export + # for capturing `.item` operations + self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment] + # pyrefly: ignore [bad-assignment] + self.weight_quant_min: int | None = weight_qparams.get("quant_min") + # pyrefly: ignore [bad-assignment] + self.weight_quant_max: int | None = weight_qparams.get("quant_max") + + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + if self.is_decomposed: + return _quantize_and_dequantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + # pyrefly: ignore [bad-argument-type] + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max, + ) + else: + return _quantize_and_dequantize_weight( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + # pyrefly: ignore [bad-argument-type] + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + ) + + def get_quantized_weight(self): + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + # assert isinstance(self.weight_axis, torch.Tensor) + if self.is_decomposed: + return _quantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + # pyrefly: ignore [bad-argument-type] + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max, + ) + else: + return _quantize_weight( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + # pyrefly: ignore [bad-argument-type] + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + _save_weight_qparams( + destination, + prefix, + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +def _quantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: int | None, + weight_quant_max: int | None, +) -> torch.Tensor: + _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (-2147483648, 2147483647), # torch.jit interprets 2**31 as a float + } + + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[ + weight_dtype_ + ] + weight = torch.ops.quantized_decomposed.quantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + # TODO: torch.quint4x2 is not supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[ + weight_dtype_ + ] + weight = torch.ops.quantized_decomposed.quantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: int | None, + weight_quant_max: int | None, +) -> torch.Tensor: + # TODO: get the quant_min and quant_max from activation_post_process + _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (-2147483648, 2147483647), # torch.jit interprets 2**31 as a float + } + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_] + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + # TODO: torch.quint4x2 is not supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _quantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, +) -> torch.Tensor: + if weight_dtype == torch.float16: + weight = weight.to(weight_dtype) + return weight + + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.quantize_per_tensor( + weight, weight_scale, weight_zero_point, weight_dtype + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]: + weight = torch.quantize_per_channel( + weight, weight_scale, weight_zero_point, weight_axis_int, weight_dtype + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _quantize_and_dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, + weight_quant_min: int | None, + weight_quant_max: int | None, +) -> torch.Tensor: + """Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + weight_quant = _quantize_weight_decomposed( + weight, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + weight_quant_min, + weight_quant_max, + ) + weight_dequant = _dequantize_weight_decomposed( + weight_quant, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + weight_quant_min, + weight_quant_max, + ) + else: + weight_dequant = weight + return weight_dequant + + +def _quantize_and_dequantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, +) -> torch.Tensor: + """Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + weight_quant = _quantize_weight( + weight, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + ) + weight_dequant = weight_quant.dequantize() + else: + weight_dequant = weight + return weight_dequant + + +def _save_weight_qparams( + destination, + prefix, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis, +): + destination[prefix + "weight_qscheme"] = weight_qscheme + destination[prefix + "weight_dtype"] = weight_dtype + if weight_qscheme is not None: + destination[prefix + "weight_scale"] = weight_scale + destination[prefix + "weight_zero_point"] = weight_zero_point + if weight_qscheme == torch.per_channel_affine: + destination[prefix + "weight_axis"] = weight_axis + + +def _get_weight_qparam_keys(state_dict: dict[str, typing.Any], prefix: str): + keys = ["weight_qscheme", "weight_dtype"] + weight_qscheme = state_dict[prefix + "weight_qscheme"] + if weight_qscheme is not None: + keys.append("weight_scale") + keys.append("weight_zero_point") + if weight_qscheme == torch.quantize_per_channel: + keys.append("weight_axis") + return keys diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fda5a58f2984ee05b0d167297b458f62c37fc59 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/__init__.py @@ -0,0 +1 @@ +from . import quantized diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d4e4b12751b92ba440250ab72cbdf374a5d00e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef66c90b0e8ecdbc7cd2cfb4c1cecf0bc38e8466 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__init__.py @@ -0,0 +1,10 @@ +from torch.ao.nn.sparse.quantized import dynamic + +from .linear import Linear, LinearPackedParams + + +__all__ = [ + "dynamic", + "Linear", + "LinearPackedParams", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbc54430939b4b7d64baa4b2eeba23099f50a365 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75fb8ff1831d97562baf6f0ab48f4e0a8f543670 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fabd34376868c4919ef636fb736662e70e70a21 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91ecfd8793dc08b96ed64f47f531724aa8a866d0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py @@ -0,0 +1,6 @@ +from .linear import Linear + + +__all__ = [ + "Linear", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..d327cabd0d3681cce4ec4b7d62f0f9e734ad0730 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs + +import torch +import torch.ao.nn.intrinsic as nni +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _quantize_weight, +) +from torch.ao.nn.sparse.quantized import linear +from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern + + +__all__ = ["Linear"] + + +class Linear(torch.nn.Module): + r""" + A dynamically quantized sparse linear module with float tensor as inputs and outputs. + """ + + _version = 1 + _op_type = "sparse_dynamic" + _FLOAT_MODULE = torch.nn.Linear + + def __init__( + self, + in_features, + out_features, + row_block_size, + col_block_size, + bias=True, + dtype=torch.qint8, + ): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError( + "Only QINT8 is supported for Sparse Quantized Linear Dynamic" + ) + + self.in_features = in_features + self.out_features = out_features + + if bias: + bias = torch.zeros(self.out_features, dtype=torch.float) + else: + bias = None + + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + self._packed_params = linear.LinearPackedParams( + row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype + ) + self._packed_params.set_weight_bias( + qweight, bias, row_block_size, col_block_size + ) + + def _get_name(self): + return "SparseQuantizedDynamicLinear" + + def extra_repr(self): + return f"in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}" + + def __repr__(self): + return _hide_packed_params_repr(self, linear.LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "op_type"] = self._op_type + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + op_type = int(state_dict[prefix + "op_type"]) + assert op_type == "sparse", ( + f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]" + ) + state_dict.pop(prefix + "op_type") + + version = local_metadata.get("version", None) + assert version <= self._version + + # Is this code valid? In old quantization it seemed to be used to load + # older model + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias") + state_dict.update( + { + prefix + "_packed_params.weight": weight, + prefix + "_packed_params.bias": bias, + } + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias( + self, + w: torch.Tensor, + b: torch.Tensor | None, + row_block_size: int | None, + col_block_size: int | None, + ) -> None: + assert row_block_size is not None and col_block_size is not None + self.out_features = w.shape[0] + self.in_features = w.shape[1] + self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized sparse dynamic module from a float module. + + We only care about the convert at this stage, no need for observers just yet. + """ + assert type(mod) is cls._FLOAT_MODULE, ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + # TODO: Need to add options to qconfig to avoid the calibration. + # TODO: Add calibration for the sparsity + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + if type(mod) is nni.LinearReLU: + mod = mod[0] + # pyrefly: ignore [missing-attribute] + if mod.qconfig is not None and mod.qconfig.weight is not None: + # pyrefly: ignore [not-callable] + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer = default_dynamic_qconfig.weight() + + # It is important to multiply by the mask BEFORE calling the `weight_observer` + # TODO (zaf): Mask might not be part of the qconfig (T83295194) + weight = mod.weight + if getattr(mod.qconfig, "mask", False): + weight = mod.qconfig.mask * mod.weight + + weight_observer(weight) + dtype = weight_observer.dtype + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + _w_sc, w_zp = weight_observer.calculate_qparams() + if isinstance(w_zp, torch.Tensor): + assert not torch.any(w_zp.bool()), "All weight zero points must map to 0" + else: + assert w_zp == 0, "Weight zero point must map to 0" + qweight = _quantize_weight(weight.float(), weight_observer) + + row_block_size, col_block_size = LinearBlockSparsePattern.block_size() + qlinear = cls( + mod.in_features, + mod.out_features, + row_block_size, + col_block_size, + dtype=dtype, + ) + # pyrefly: ignore [bad-argument-type] + qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size) + return qlinear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/linear.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f106a32abfbf960b989c8eba860db2dec4a7fe4c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/linear.py @@ -0,0 +1,274 @@ +# mypy: allow-untyped-defs + +import torch +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _quantize_weight, +) + + +__all__ = ["LinearPackedParams", "Linear"] + + +# TODO (zaf): Inherit from `quantized.LinearPackedParams` (T83294430) +class LinearPackedParams(torch.nn.Module): + _version = 1 + + def __init__(self, row_block_size=1, col_block_size=4, dtype=torch.qint8): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError("Linear prepacking only supports QINT8") + self.dtype = dtype + wq = torch._empty_affine_quantized( + [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8 + ) + self.set_weight_bias(wq, None, row_block_size, col_block_size) + + def _get_name(self): + return "SparseQuantizedLinearPackedParams" + + @torch.jit.export + def set_weight_bias( + self, + weight: torch.Tensor, + bias: torch.Tensor | None, + row_block_size: int | None, + col_block_size: int | None, + ) -> None: + assert row_block_size is not None and col_block_size is not None + self._packed_params = torch.ops.sparse.qlinear_prepack( + weight, bias, row_block_size, col_block_size + ) + + @torch.jit.export + def _weight_bias(self): + (weight, bias, block_sizes) = torch.ops.sparse.qlinear_unpack( + self._packed_params + ) + return (weight, bias, block_sizes[0], block_sizes[1]) + + def forward(self, x): + return x + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_params"] = self._weight_bias() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + assert version <= self._version + + self.dtype = state_dict.pop(prefix + "dtype") + weight, bias, row_block_size, col_block_size = state_dict.pop( + prefix + "_packed_params" + ) + self.set_weight_bias(weight, bias, row_block_size, col_block_size) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def __getstate__(self): + return self._packed_params, self.training, self.dtype + + @torch.jit.export + def __setstate__(self, state): + (self._packed_params, self.training, self.dtype) = state + + def __repr__(self): + return self._weight_bias().__repr__() + + +# TODO (zaf): Inherit from `quantized.Linear` (T83294430) +class Linear(torch.nn.Module): + r""" + A quantized sparse linear module with quantized tensor as inputs and outputs. + """ + + _version = 1 + _FLOAT_MODULE = torch.nn.Linear + + def __init__( + self, + in_features, + out_features, + row_block_size, + col_block_size, + bias=True, + dtype=torch.qint8, + ): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError( + "Only QINT8 is supported for Sparse Quantized Linear" + ) + + self.in_features = in_features + self.out_features = out_features + + if bias: + bias = torch.zeros(self.out_features, dtype=torch.float) + else: + bias = None + + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + self._packed_params = LinearPackedParams( + row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype + ) + self._packed_params.set_weight_bias( + qweight, bias, row_block_size, col_block_size + ) + self.scale = 1.0 + self.zero_point = 0 + + @classmethod + def _get_name(cls): + return "SparseQuantizedLinear" + + def extra_repr(self): + return ( + f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, " + f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}" + ) + + def __repr__(self): + return _hide_packed_params_repr(self, LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.sparse.qlinear( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + + state_dict.pop(prefix + "op_type") + + version = local_metadata.get("version", None) + assert version <= self._version + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias( + self, + w: torch.Tensor, + b: torch.Tensor | None, + row_block_size: int | None, + col_block_size: int | None, + ) -> None: + assert row_block_size is not None and col_block_size is not None + self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized sparse module from a float module. + + We only care about the convert at this stage, no need for observers just yet. + + TODO(zaf): Need to add the sparse params to the qconfig + """ + assert type(mod) is cls._FLOAT_MODULE, ( + cls._get_name() + ".from_float only works for " + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "sparse_params"), ( + "Expecting the Linear to have `sparse_params`. Make sure you have provided arguments " + 'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.' + ) + sparse_block_shape = mod.sparse_params.get("sparse_block_shape", None) # type: ignore[operator, union-attr] + assert isinstance(sparse_block_shape, (tuple, list)) + assert len(sparse_block_shape) == 2 + # TODO: Need to add options to qconfig to avoid the calibration. + # TODO: Add calibration for the sparsity + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr] + + # Assumption is that the weight is already sparsified by the + # `sparsifier.convert` + weight = mod.weight + + weight_post_process(weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[operator, union-attr] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + w_sc, w_zp = weight_post_process.calculate_qparams() + if isinstance(w_zp, torch.Tensor): + assert not torch.any(w_zp.bool()), "All weight zero points must map to 0" + else: + assert w_zp == 0, "Weight zero point must map to 0" + qweight = _quantize_weight(weight.float(), weight_post_process) + + row_block_size = mod.sparse_params["sparse_block_shape"][0] # type: ignore[index] + col_block_size = mod.sparse_params["sparse_block_shape"][1] # type: ignore[index] + qlinear = cls( + mod.in_features, + mod.out_features, + row_block_size, + col_block_size, + dtype=dtype, + ) + qlinear.set_weight_bias( + qweight, + mod.bias, + row_block_size, # type: ignore[arg-type] + col_block_size, # type: ignore[arg-type] + ) + qlinear.scale = float(act_scale) + qlinear.zero_point = int(act_zp) + return qlinear diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2cfd4a5973dfa8a5219f5ca97246424ae17a6308 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/utils.py @@ -0,0 +1,62 @@ +import threading + + +__all__ = ["LinearBlockSparsePattern"] + + +def _is_valid_linear_block_sparse_pattern( + row_block_size: int, col_block_size: int +) -> bool: + return (row_block_size == 1 and col_block_size == 4) or ( + row_block_size == 8 and col_block_size == 1 + ) + + +# This is a stop-gap measure as current flow does not allow module +# specific block sparse pattern. +# In fact there is no way to convey sparse pattern via module config +# of quantization flow. Thus using the global context to convey +# sparsity pattern. +# Once the flow supports it, this should be removed. +class LinearBlockSparsePattern: + rlock = threading.RLock() + row_block_size: int = 1 + col_block_size: int = 4 + prev_row_block_size: int = 1 + prev_col_block_size: int = 4 + + def __init__(self, row_block_size: int = 1, col_block_size: int = 4): + assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size) + LinearBlockSparsePattern.rlock.acquire() + LinearBlockSparsePattern.prev_row_block_size = ( + LinearBlockSparsePattern.row_block_size + ) + LinearBlockSparsePattern.prev_col_block_size = ( + LinearBlockSparsePattern.col_block_size + ) + LinearBlockSparsePattern.row_block_size = row_block_size + LinearBlockSparsePattern.col_block_size = col_block_size + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + backtrace: object | None, + ) -> None: + LinearBlockSparsePattern.row_block_size = ( + LinearBlockSparsePattern.prev_row_block_size + ) + LinearBlockSparsePattern.col_block_size = ( + LinearBlockSparsePattern.prev_col_block_size + ) + LinearBlockSparsePattern.rlock.release() + + @staticmethod + def block_size() -> tuple[int, int]: + return ( + LinearBlockSparsePattern.row_block_size, + LinearBlockSparsePattern.col_block_size, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..026ac73606e307bedd500a801a76ba1a97c4c655 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite.py @@ -0,0 +1,568 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.nn as nn +from torch.ao.quantization import prepare +from torch.ao.quantization.quantization_mappings import ( + get_default_compare_output_module_list, +) + + +NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = { + nnqd.Linear, + nnq.Linear, + nnqd.LSTM, + nn.LSTM, +} + + +def _find_match( + str_list: dict[str, Any] | list[str], + key_str: str, + postfix: str, +) -> str | None: + split_str = key_str.split(".") + if split_str[-1] == postfix: + match_string = "".join(key_str.split(".")[0:-1]) + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + + # For matching "fc.weight" and "fc._packed_params._packed_params" + if postfix == "_packed_params": + match_string = "".join(key_str.split(".")[0:-2]) + if len(match_string) == 0: + return None + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + return None + else: + return None + + +def compare_weights( + float_dict: dict[str, Any], quantized_dict: dict[str, Any] +) -> dict[str, dict[str, torch.Tensor]]: + r"""Compare the weights of the float module with its corresponding quantized + module. Return a dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights. This dict can be used to compare and compute the quantization + error of the weights of float and quantized models. + + Example usage:: + + wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict()) + for key in wt_compare_dict: + print( + key, + compute_error( + wt_compare_dict[key]["float"], + wt_compare_dict[key]["quantized"].dequantize(), + ), + ) + + Args: + float_dict: state dict of the float model + quantized_dict: state dict of the quantized model + + Return: + weight_dict: dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights") + weight_dict: dict[str, dict] = {} + for key in quantized_dict: + match_key = _find_match(float_dict, key, "weight") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key] + continue + + # For matching "fc.weight" and "fc._packed_params._packed_params" + match_key = _find_match(float_dict, key, "_packed_params") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key][0] + + # For LSTM + split_str = key.split(".") + if split_str[-1] == "param" and split_str[-3] == "_all_weight_values": + layer = split_str[-2] + module_name = ".".join(split_str[:-3]) + float_weight_ih_key = module_name + ".weight_ih_l" + layer + float_weight_hh_key = module_name + ".weight_hh_l" + layer + if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[float_weight_ih_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0] + ) + weight_dict[key]["float"] = float_dict[float_weight_hh_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0] + ) + + return weight_dict + + +def _get_logger_dict_helper( + mod: nn.Module, + target_dict: dict[str, Any], + prefix: str = "", +) -> None: + r"""This is the helper function for get_logger_dict + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + target_dict: the dictionary used to save all logger stats + """ + + def get_prefix(prefix): + return prefix if prefix == "" else prefix + "." + + for child in mod.children(): + if isinstance(child, Logger): + target_dict[get_prefix(prefix) + "stats"] = child.stats + break + + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_logger_dict_helper(child, target_dict, module_prefix) + + +def get_logger_dict(mod: nn.Module, prefix: str = "") -> dict[str, dict]: + r"""Traverse the modules and save all logger stats into target dict. + This is mainly used for quantization accuracy debug. + + Type of loggers supported: + ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module, + OutputLogger: used to log the outputs of the modules + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + + Return: + target_dict: the dictionary used to save all logger stats + + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict") + + target_dict: dict[str, dict] = {} + _get_logger_dict_helper(mod, target_dict, prefix) + return target_dict + + +class Logger(nn.Module): + r"""Base class for stats logging""" + + def __init__(self): + super().__init__() + self.stats = {} + # We only insert observer if the op is quantized with static quantization, + # which is identified by activation_observer.dtype == quint8. This is needed + # when attaching Logger as observer for FX mode + self.dtype = torch.quint8 + + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + + +class ShadowLogger(Logger): + r"""Class used in Shadow module to record the outputs of the original and + shadow modules. + """ + + def __init__(self): + super().__init__() + self.stats["float"] = [] + self.stats["quantized"] = [] + + def forward(self, x, y): # type: ignore[override] + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if len(x) > 1: + x = x[0] + if len(y) > 1: + y = y[0] + self.stats["quantized"].append(x.detach()) + self.stats["float"].append(y.detach()) + + +class OutputLogger(Logger): + r"""Class used to log the outputs of the module""" + + def __init__(self): + super().__init__() + self.stats["tensor_val"] = [] + + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + self.stats["tensor_val"].append(x) + return x + + +def _convert_tuple_to_list(t: Any) -> Any: + return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t + + +def _dequantize_tensor_list(t: Any) -> Any: + return ( + [_dequantize_tensor_list(x) for x in t] + if type(t) is list + else t.dequantize() + if t.is_quantized + else t + ) + + +class Shadow(nn.Module): + r"""Shadow module attaches the float module to its matching quantized module + as the shadow. Then it uses Logger module to process the outputs of both + modules. + + Args: + q_module: module quantized from float_module that we want to shadow + float_module: float module used to shadow q_module + logger_cls: type of logger used to process the outputs of q_module and + float_module. ShadowLogger or custom loggers can be used. + """ + + def __init__(self, q_module, float_module, logger_cls): + super().__init__() + self.orig_module = q_module + self.shadow_module = float_module + self.dequant = nnq.DeQuantize() + self.logger = logger_cls() + + def forward(self, *x) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + xl = _convert_tuple_to_list(x) + output = self.orig_module(*xl) + xl_float = _dequantize_tensor_list(xl) + shadow_output = self.shadow_module(*xl_float) + self.logger(output, shadow_output) + return output + + def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add(x, y) + self.logger(output, shadow_output) + return output + + def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.add_scalar(x, y) + self.logger(output, shadow_output) + return output + + def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.mul(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.mul(x, y) + self.logger(output, shadow_output) + return output + + def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.mul_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.mul_scalar(x, y) + self.logger(output, shadow_output) + return output + + def cat(self, x: list[torch.Tensor], dim: int = 0) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.cat(x, dim) + x = [y.dequantize() for y in x] + shadow_output = self.shadow_module.cat(x, dim) + self.logger(output, shadow_output) + return output + + def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add_relu(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add_relu(x, y) + self.logger(output, shadow_output) + return output + + +def prepare_model_with_stubs( + float_module: nn.Module, + q_module: nn.Module, + module_swap_list: set[type], + logger_cls: Callable, +) -> None: + r"""Prepare the model by attaching the float module to its matching quantized + module as the shadow if the float module type is in module_swap_list. + + Example usage:: + + prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) + q_model(data) + ob_dict = get_logger_dict(q_model) + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + module_swap_list: list of float module types to attach the shadow + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.prepare_model_with_stubs" + ) + + float_module_children = dict(float_module.named_children()) + + reassign = {} + for name, mod in q_module.named_children(): + if name not in float_module_children: + continue + + float_mod = float_module_children[name] + + if type(float_mod) not in module_swap_list: + prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls) + + # Insert shadow module only if the module is not of the same type as + # the floating point module + if type(float_mod) in module_swap_list and not _is_identical_module_type( + mod, float_mod + ): + reassign[name] = Shadow(mod, float_mod, logger_cls) + + for key, value in reassign.items(): + q_module._modules[key] = value + + +def _is_identical_module_type(mod1, mod2): + # Compare if two modules have the same dtype + mod1_module_types = [type(mod) for mod in mod1.modules()] + mod2_module_types = [type(mod) for mod in mod2.modules()] + return mod1_module_types == mod2_module_types + + +def compare_model_stub( + float_model: nn.Module, + q_model: nn.Module, + module_swap_list: set[type], + *data, + logger_cls=ShadowLogger, +) -> dict[str, dict]: + r"""Compare quantized module in a model with its floating point counterpart, + feeding both of them the same input. Return a dict with key corresponding to + module names and each entry being a dictionary with two keys 'float' and + 'quantized', containing the output tensors of quantized and its matching + float shadow module. This dict can be used to compare and compute the module + level quantization error. + + This function first call prepare_model_with_stubs() to swap the quantized + module that we want to compare with the Shadow module, which takes quantized + module, corresponding float module and logger as input, and creates a forward + path inside to make the float module to shadow quantized module sharing the + same input. The logger can be customizable, default logger is ShadowLogger + and it will save the outputs of the quantized module and float module that + can be used to compute the module level quantization error. + + Example usage:: + + module_swap_list = [ + torchvision.models.quantization.resnet.QuantizableBasicBlock + ] + ob_dict = compare_model_stub(float_model, qmodel, module_swap_list, data) + for key in ob_dict: + print( + key, + compute_error( + ob_dict[key]["float"], ob_dict[key]["quantized"].dequantize() + ), + ) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + module_swap_list: list of float module types at which shadow modules will + be attached. + data: input data used to run the prepared q_model + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub") + prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls) + q_model(*data) + ob_dict = get_logger_dict(q_model) + return ob_dict + + +def get_matching_activations( + float_module: nn.Module, + q_module: nn.Module, +) -> dict[str, dict[str, torch.Tensor]]: + r"""Find the matching activation between float and quantized modules. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + + Return: + act_dict: dict with key corresponding to quantized module names and each + entry being a dictionary with two keys 'float' and 'quantized', containing + the matching float and quantized activations + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.get_matching_activations" + ) + float_dict = get_logger_dict(float_module) + quantized_dict = get_logger_dict(q_module) + act_dict: dict[str, dict] = {} + for key in quantized_dict: + if len(quantized_dict[key]["tensor_val"]) == 0: + continue + match_key = _find_match(sorted(float_dict, reverse=True), key, "stats") + if match_key is not None: + act_dict[key] = {} + act_dict[key]["float"] = float_dict[match_key]["tensor_val"] + act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"] + return act_dict + + +def prepare_model_outputs( + float_module: nn.Module, + q_module: nn.Module, + logger_cls=OutputLogger, + allow_list=None, +) -> None: + r"""Prepare the model by attaching the logger to both float module + and quantized module if they are in the allow_list. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.prepare_model_outputs" + ) + if allow_list is None: + allow_list = get_default_compare_output_module_list() + + qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None) + float_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare( + float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={} + ) + q_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare( + q_module, + inplace=True, + allow_list=allow_list, + observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + prepare_custom_config_dict={}, + ) + + +def compare_model_outputs( + float_model: nn.Module, + q_model: nn.Module, + *data, + logger_cls=OutputLogger, + allow_list=None, +) -> dict[str, dict[str, torch.Tensor]]: + r"""Compare output activations between float and quantized models at + corresponding locations for the same input. Return a dict with key corresponding + to quantized module names and each entry being a dictionary with two keys + 'float' and 'quantized', containing the activations of quantized model and + float model at matching locations. This dict can be used to compare and + compute the propagation quantization error. + + Example usage:: + + act_compare_dict = compare_model_outputs(float_model, qmodel, data) + for key in act_compare_dict: + print( + key, + compute_error( + act_compare_dict[key]["float"], + act_compare_dict[key]["quantized"].dequantize(), + ), + ) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + data: input data used to run the prepared float_model and q_model + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + + Return: + act_compare_dict: dict with key corresponding to quantized module names + and each entry being a dictionary with two keys 'float' and 'quantized', + containing the matching float and quantized activations + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.compare_model_outputs" + ) + if allow_list is None: + allow_list = get_default_compare_output_module_list() + prepare_model_outputs(float_model, q_model, logger_cls, allow_list) + float_model(*data) + q_model(*data) + act_compare_dict = get_matching_activations(float_model, q_model) + return act_compare_dict diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite_fx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..1861d0160db152e73debda3bda7f714ca4bbf601 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite_fx.py @@ -0,0 +1,1121 @@ +# mypy: allow-untyped-defs +""" +This module contains tooling to compare weights and activations +across models. Example usage:: + + import copy + import torch + import torch.ao.quantization.quantize_fx as quantize_fx + import torch.ao.ns._numeric_suite_fx as ns + + m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval() + mp = quantize_fx.prepare_fx(m, {"": torch.ao.quantization.default_qconfig}) + # We convert a copy because we need the original prepared model + # to be available for comparisons, and `quantize_fx.convert_fx` is inplace. + mq = quantize_fx.convert_fx(copy.deepcopy(mp)) + + # + # Comparing weights + # + + # extract weight pairs + weight_comparison = ns.extract_weights("a", mp, "b", mq) + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + weight_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # weight_comparison contains the weights from `mp` and `mq` stored + # in pairs, and can be used for further analysis. + + + # + # Comparing activations, with error propagation + # + + # add loggers + mp_ns, mq_ns = ns.add_loggers( + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_ns(datum) + mq_ns(datum) + + # extract intermediate activations + act_comparison = ns.extract_logger_info(mp_ns, mq_ns, ns.OutputLogger, "b") + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. + + # + # Comparing activations, without error propagation + # + + # create shadow model + mp_shadows_mq = ns.add_shadow_loggers( + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_shadows_mq(datum) + + # extract intermediate activations + shadow_act_comparison = ns.extract_shadow_logger_info( + mp_shadows_mq, ns.OutputLogger, "b" + ) + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + shadow_act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. + +""" + +import collections +from collections.abc import Callable +from typing import Any, TYPE_CHECKING + +import torch +import torch.ao.quantization.quantize_fx as quantize_fx +import torch.nn as nn +from torch.ao.ns.fx.graph_matcher import get_matching_subgraph_pairs +from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops +from torch.ao.ns.fx.n_shadows_utils import ( + _get_dedup_subgraphs, + create_add_loggers_graph, + create_n_transformed_and_logged_copies_of_subgraph, + create_results_comparison, + extract_weight_comparison, + group_results_by_subgraph, + OutputProp, + print_n_shadows_summary, + SHADOW_WRAPPER_NODE_NAME_PREFIX, +) +from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.backend_config.utils import ( + get_fusion_pattern_to_root_node_getter, +) +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr +from torch.ao.quantization.fx.match_utils import _find_matches +from torch.ao.quantization.fx.qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, +) +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .fx.graph_passes import add_loggers_to_model, create_a_shadows_b +from .fx.ns_types import NSNodeTargetType, NSResultsType, NSSingleResultValuesType +from .fx.utils import ( + get_target_type_str, + maybe_add_missing_fqns, + rekey_logger_info_on_node_name_of_model, +) +from .fx.weight_utils import extract_weight_from_node + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfigAny + +RNNReturnType = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] + + +class OutputLogger(nn.Module): + """ + Base class for capturing intermediate values. + """ + + stats: list[torch.Tensor] + stats_rnn: list[RNNReturnType] + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__( + self, + ref_node_name: str, + prev_node_name: str, + model_name: str, + ref_name: str, + prev_node_target_type: str, + ref_node_target_type: str, + results_type: str, + index_within_arg: int, + index_of_arg: int, + fqn: str | None, + qconfig_str: str | None = "", + ): + super().__init__() + self.stats: list[torch.Tensor] = [] + self.stats_rnn: list[RNNReturnType] = [] + + # name of the node which was responsible for adding this logger + # Note: + # - if we are logging node outputs, this is the same as prev_node_name + # - if we are logging node inputs, this is the name of the node + # whose input this logger is logging. + # + # example, where logger1 is logging input of op1 and logger2 is logging + # the output of op1: + # + # x1 -> logger1 -> op1 -> logger2 -> x2 + # + # in this example, + # - logger1's prev_node_name is x1 and ref_node_name is op1 + # - logger2's prev_node_name is op1 and ref_node_name is op1 + self.ref_node_name = ref_node_name + # name of the node whose output this Logger is capturing + self.prev_node_name = prev_node_name + + # name of the model from which the node originated from + self.model_name = model_name + # reference name, used to match loggers from separate models + # to each other + self.ref_name = ref_name + # type of the target of the node whose output this logger is logging + self.prev_node_target_type = prev_node_target_type + # type of the target of the node which was responsible for adding this + # logger + self.ref_node_target_type = ref_node_target_type + # what kind of values are inside of stats + self.results_type = results_type + # index of this node within the arg of the input/output node + # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1 + self.index_within_arg = index_within_arg + # index of this node within the args of the input/output node + # for example, in add(x1, x2), x2 would have index_of_arg == 1 + self.index_of_arg = index_of_arg + # fully qualified name + self.fqn = fqn + # if loggers are added before prepare_fx, but we do not want + # collect results of calibration, only results after convert_fx + # so, we add a flag to control whether this logger collects data + self.enabled = True + # string representation of qconfig + self.qconfig_str = qconfig_str + # this can be turned off to reduce memory usage during calibration + self.save_activations = True + + # Note: cannot annotate the type of x because TorchScript does not support + # the Union type. + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + # TODO(future PR): consider designing this better, as the difference + # between these two flags is subtle and not obvious. + if not self.enabled: + return x + if not self.save_activations: + return x + # TODO(future PR): consider refactoring this to better reuse the parent + # class + if isinstance(x, torch.Tensor): + self.stats.append(x.detach()) + elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2: + new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach())) + self.stats_rnn.append(new_res) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != "training") and not k.startswith("_") + } + return f"OutputLogger({clean_dict})" + + +class OutputComparisonLogger(OutputLogger): + """ + Same as OutputLogger, but also requires the original activation + in order to calculate the comparison at calibration time + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO(future PR): make the comparison function configurable + self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr + self.comparison_fn_name = "sqnr" + # precalculated comparisons of logger output versus reference + self.comparisons = [] + # precalculated comparisons function + + def forward(self, x, x_ref): # type: ignore[override] + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if not self.enabled: + return x + if not isinstance(x, torch.Tensor): + raise AssertionError("non-tensor inputs not yet supported") + if self.save_activations: + # save the activation, for debugging + self.stats.append(x.detach()) + # save the comparison + self.comparisons.append(self.comparison_fn(x, x_ref)) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != "training") and not k.startswith("_") + } + return f"OutputComparisonLogger({clean_dict})" + + +class NSTracer(quantize_fx.QuantizationTracer): + """ + Just like a regular FX quantization tracer, but treats observers and fake_quantize + modules as leaf modules. + """ + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if isinstance(m, torch.ao.quantization.ObserverBase): + return True + elif isinstance(m, torch.ao.quantization.FakeQuantizeBase): + return True + return super().is_leaf_module(m, module_qualified_name) + + +def _extract_weights_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument: list[tuple[Node, str]], + results: NSResultsType, + op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] + | None = None, +) -> None: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_weights_one_model" + ) + for node, ref_name in nodes_and_names_to_instrument: + res_type = NSSingleResultValuesType.WEIGHT.value + extracted_weight = extract_weight_from_node( + node, model, op_to_type_to_weight_extraction_fn + ) + if extracted_weight: + if ref_name not in results: + results[ref_name] = {res_type: {}} + results[ref_name][res_type][model_name] = [extracted_weight] + + +def _extract_weights_impl( + model_name_a: str, + gm_a: GraphModule, + model_name_b: str, + gm_b: GraphModule, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] | None = None, + unmatchable_types_map: dict[str, set[NSNodeTargetType]] | None = None, + op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] + | None = None, +) -> NSResultsType: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_weights_impl" + ) + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + + # split the subgraph pairs into one data structure for each model + nodes_and_names_to_instrument_a: list[tuple[Node, str]] = [] + nodes_and_names_to_instrument_b: list[tuple[Node, str]] = [] + for match_name, match in matched_subgraph_pairs.items(): + subgraph_a, subgraph_b = match + nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name)) + nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name)) + + # populate the results, one model at a time + results: NSResultsType = {} + _extract_weights_one_model( + model_name_a, + gm_a, + nodes_and_names_to_instrument_a, + results, + op_to_type_to_weight_extraction_fn, + ) + _extract_weights_one_model( + model_name_b, + gm_b, + nodes_and_names_to_instrument_b, + results, + op_to_type_to_weight_extraction_fn, + ) + + # fill in missing fqn entries + maybe_add_missing_fqns(results) + + # rekey on names of nodes in gm_b + results = rekey_logger_info_on_node_name_of_model(results, model_name_b) + + return results + + +def extract_weights( + model_name_a: str, + model_a: nn.Module, + model_name_b: str, + model_b: nn.Module, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] | None = None, + unmatchable_types_map: dict[str, set[NSNodeTargetType]] | None = None, + op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] + | None = None, +) -> NSResultsType: + """ + Extract weights from model A and model B, and return a comparison. + + Args: + model_name_a: string name of model A to use in results + model_a: model A + model_name_b: string name of model B to use in results + model_b: model B + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + op_to_type_to_weight_extraction_fn: optional override of function which extracts weight + from a type, subject to change + + Return: + NSResultsType, containing the weight comparisons + """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights") + if base_name_to_sets_of_related_ops is None: + base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() + + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _extract_weights_impl( + model_name_a, + gm_a, + model_name_b, + gm_b, + base_name_to_sets_of_related_ops, + unmatchable_types_map, + op_to_type_to_weight_extraction_fn, + ) + + +def _add_loggers_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument_inputs: list[tuple[Node, str, str]], + nodes_and_names_to_instrument_outputs: list[tuple[Node, str, str]], + logger_cls: Callable, +) -> nn.Module: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._add_loggers_one_model" + ) + + # TODO(future PR): do not observe nodes we do not care + # about (both fp32, denylist, etc) + node_to_instrument_inputs_to_ref_name: dict[Node, tuple[str, str]] = {} + node_to_instrument_outputs_to_ref_name: dict[Node, tuple[str, str]] = {} + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs: + node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type) + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs: + node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type) + + model = add_loggers_to_model( + model, + node_to_instrument_inputs_to_ref_name, + node_to_instrument_outputs_to_ref_name, + logger_cls, + model_name, + ) + return model + + +def _add_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] | None = None, + unmatchable_types_map: dict[str, set[NSNodeTargetType]] | None = None, +) -> tuple[nn.Module, nn.Module]: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl") + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + nodes_and_names_to_instrument_inputs_a = [] + nodes_and_names_to_instrument_inputs_b = [] + nodes_and_names_to_instrument_outputs_a = [] + nodes_and_names_to_instrument_outputs_b = [] + for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items(): + ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a) + ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b) + # Note: for matching inputs we use start_node, such as observing + # the input of linear in linear-relu + if should_log_inputs: + nodes_and_names_to_instrument_inputs_a.append( + (subgraph_a.start_node, match_name, ref_node_type_a) + ) + nodes_and_names_to_instrument_inputs_b.append( + (subgraph_b.start_node, match_name, ref_node_type_b) + ) + # Note: for matching activations we always use end_node, + # such as observing the output of relu in linear-relu + nodes_and_names_to_instrument_outputs_a.append( + (subgraph_a.end_node, match_name, ref_node_type_a) + ) + nodes_and_names_to_instrument_outputs_b.append( + (subgraph_b.end_node, match_name, ref_node_type_b) + ) + + new_model_a = _add_loggers_one_model( + name_a, + gm_a, + nodes_and_names_to_instrument_inputs_a, + nodes_and_names_to_instrument_outputs_a, + logger_cls, + ) + new_model_b = _add_loggers_one_model( + name_b, + gm_b, + nodes_and_names_to_instrument_inputs_b, + nodes_and_names_to_instrument_outputs_b, + logger_cls, + ) + return (new_model_a, new_model_b) + + +def add_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs: bool = False, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] | None = None, + unmatchable_types_map: dict[str, set[NSNodeTargetType]] | None = None, +) -> tuple[nn.Module, nn.Module]: + """ + Instrument model A and model B with loggers. + + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + + Return: + Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace. + """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers") + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_loggers_impl( + name_a, + gm_a, + name_b, + gm_b, + logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + unmatchable_types_map=unmatchable_types_map, + ) + + +def _extract_logger_info_one_model( + model: nn.Module, + results: NSResultsType, + logger_cls: Callable, +) -> None: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_logger_info_one_model" + ) + for _gm_name, mod in model.named_modules(): + # TODO(future PR): better check when scripted + is_logger = isinstance(mod, logger_cls) or ( # type: ignore[arg-type] + isinstance(mod, torch.jit.RecursiveScriptModule) + and mod.original_name == "OutputLogger" + ) + if is_logger: + key = mod.ref_name + if key not in results: + results[key] = {} + if mod.model_name in results[key]: + raise AssertionError(f"{mod.model_name} is already present in results") + if mod.results_type not in results[key]: + results[key][mod.results_type] = {} + if mod.model_name not in results[key][mod.results_type]: + results[key][mod.results_type][mod.model_name] = [] + stats_to_use = mod.stats + if len(mod.stats_rnn) > 0: + stats_to_use = mod.stats_rnn + data = { + "type": mod.results_type, + "values": stats_to_use, + "ref_node_name": mod.ref_node_name, + "ref_node_target_type": mod.ref_node_target_type, + "prev_node_name": mod.prev_node_name, + "prev_node_target_type": mod.prev_node_target_type, + "index_within_arg": mod.index_within_arg, + "index_of_arg": mod.index_of_arg, + "fqn": mod.fqn, + "qconfig_str": mod.qconfig_str, + } + if hasattr(mod, "comparisons"): + data["comparisons"] = mod.comparisons + data["comparison_fn_name"] = mod.comparison_fn_name + else: + data["comparisons"] = [] + data["comparison_fn_name"] = "" + results[key][mod.results_type][mod.model_name].append(data) + # ensure the list stays sorted + results[key][mod.results_type][mod.model_name].sort( + key=lambda res: f"{res['index_of_arg']}:{res['index_within_arg']}" + ) + + +# TODO(future PR): align on naming +# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs` +def extract_logger_info( + model_a: nn.Module, + model_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in `model_a` and `model_b`, and extract the logged + information. + + Args: + model_a: model A + model_b: model B + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.extract_logger_info" + ) + results: NSResultsType = {} + for model in (model_a, model_b): + _extract_logger_info_one_model(model, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names + ) + return results + + +def _add_shadow_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] | None = None, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]] | None = None, + unmatchable_types_map: dict[str, set[NSNodeTargetType]] | None = None, +) -> nn.Module: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._add_shadow_loggers_impl" + ) + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + gm_a_shadows_b = create_a_shadows_b( + name_a, + gm_a, + name_b, + gm_b, + matched_subgraph_pairs, + logger_cls, + should_log_inputs=should_log_inputs, + node_type_to_io_type_map=node_type_to_io_type_map, + ) + return gm_a_shadows_b + + +def add_shadow_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs: bool = False, + base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] | None = None, + node_type_to_io_type_map: dict[str, set[NSNodeTargetType]] | None = None, + unmatchable_types_map: dict[str, set[NSNodeTargetType]] | None = None, +) -> nn.Module: + """ + Instrument model A and model B with shadow loggers. + + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + should_log_inputs: whether to log inputs + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.add_shadow_loggers" + ) + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_shadow_loggers_impl( + name_a, + gm_a, + name_b, + gm_b, + logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + node_type_to_io_type_map=node_type_to_io_type_map, + unmatchable_types_map=unmatchable_types_map, + ) + + +def extract_shadow_logger_info( + model_a_shadows_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in a shadow model, and extract the logged + information. + + Args: + model_a_shadows_b: shadow model + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.extract_shadow_logger_info" + ) + results: NSResultsType = collections.defaultdict(dict) + _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names + ) + return dict(results) + + +def extend_logger_results_with_comparison( + results: NSResultsType, + model_name_1: str, + model_name_2: str, + comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + comparison_name: str, +) -> None: + """ + Compares the logged values from `model_name_2` against the corresponding + values in `model_name_1`, using `comparison_fn`. Records the result + in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace. + + Args: + results: the result data structure from `extract_logger_info` or + `extract_shadow_logger_info`. + model_name_1: string name of model 1 + model_name_2: string name of model 2 + comparison_fn: function to compare two Tensors + comparison_name: string name of model to use for + layer names in the output + """ + for results_type_to_results in results.values(): + for model_name_to_results in results_type_to_results.values(): + if model_name_1 not in model_name_to_results: + raise AssertionError(f"{model_name_1} not found in results") + if model_name_2 not in model_name_to_results: + raise AssertionError(f"{model_name_2} not found in results") + + results_1 = model_name_to_results[model_name_1] + results_2 = model_name_to_results[model_name_2] + + for result_2 in results_2: + index_within_arg_2 = result_2["index_within_arg"] + index_of_arg_2 = result_2["index_of_arg"] + # find corresponding result_1 + result_1 = None + for cur_result_1 in results_1: + index_within_arg_1 = cur_result_1["index_within_arg"] + index_of_arg_1 = cur_result_1["index_of_arg"] + if (index_within_arg_1 == index_within_arg_2) and ( + index_of_arg_1 == index_of_arg_2 + ): + result_1 = cur_result_1 + break + if result_1 is None: + raise AssertionError("Expected result_1 to be not None") + + values_1 = result_1["values"] + values_2 = result_2["values"] + result_2[comparison_name] = [] + for value_1, value_2 in zip(values_1, values_2): + comparison_result = comparison_fn(value_1, value_2) + result_2[comparison_name].append(comparison_result) + + +def prepare_n_shadows_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_multi_mapping: QConfigMultiMapping, + backend_config: BackendConfig, + custom_prepare_fn: Callable | None = None, + custom_prepare_kwargs: dict[str, Any] | None = None, + custom_tracer: Any = None, +) -> GraphModule: + """ + Given a model with a graph with M ops such as + + + args_kwargs_m -> op_m -> output_m + + + And a set of N qconfigs for each op, creates a new model, with + each of the subgraph of `op_m` transformed into + + .. code:: + + |---------> op_m_n -> log_m_n + | / + args_kwargs_m ---------> op_m -> log_m_0 + + Where op_m_n is op_m wrapped in a submodule and transformed with + qconfig_n, and its inner graph looks like + + .. code:: + + args_m -------- op_m_prepared_with_qconfig_n -> out_m_n + / + kwargs_m --- + + This is useful for testing different quantization of multiple layers in + a single pass through the model. + + High level TODOs for future PRs: + * figure out a better way to name the output structure + * return a results data structure instead of printing it out + * add examples to docblocks + """ + + if custom_tracer is None: + tracer = quantize_fx.QuantizationTracer([], []) + else: + tracer = custom_tracer + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment] + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. + modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: list[str] = [] + standalone_module_classes: list[type] = [] + custom_module_classes: list[type] = [] + matches = _find_matches( + mt.graph, + modules, + patterns, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + subgraphs_dedup: dict[str, list[Node]] = _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + # TODO(future PR): deduplicate repeating entries + list_of_node_name_to_qconfig: list[dict[str, QConfigAny]] = [] + for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list: + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope + ) + list_of_node_name_to_qconfig.append(node_name_to_qconfig) + + # For each region in the model, do the following: + # For each qconfig for that region, do the following: + # 1. create a copy of the region wrapped in a module + # 2. pass original args, original kwargs, and expected output to module + # 3. add an output comparison logger and hook it up to compare + # actual output to expected output + # 4. run `prepare_fx` on the module + for subgraph_idx, (match_name, nodes_in_this_subgraph) in enumerate( + subgraphs_dedup.items() + ): + create_n_transformed_and_logged_copies_of_subgraph( + mt, + subgraph_idx, + match_name, + nodes_in_this_subgraph, + qconfig_multi_mapping.qconfig_mappings_list, + list_of_node_name_to_qconfig, + custom_prepare_fn, + custom_prepare_kwargs, # type: ignore[arg-type] + ) + + return mt + + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _prepare_n_shadows_add_loggers_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> torch.nn.Module: + r""" + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + + This creates a model which provides logging for the following + problem: if we quantize `model` with `qconfig_mapping` and feed + the same input through both models, log the comparisons of + corresponding intermediate layers. + + The problem is solved with a single model. Specifically, we + partition `model` into N subgraphs, create a copy of each relevant + subgraph, wrap it in a module, apply the quantization API to that + module, and hook up loggers to measure the comparisons. + + Example starting graph: + + x0 -> op0 -> x1 -> op1 -> x2 + + Example config: quantize op0 to int8, do nothing to op1. + The following graph will be created: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog + + Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized + to int8, op1_0 is op1 (appearing in the graph twice), log is a logger, + and clog is a comparison logger. + """ + + tracer = quantize_fx.QuantizationTracer([], []) + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment] + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. + modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: list[str] = [] + standalone_module_classes: list[type] = [] + custom_module_classes: list[type] = [] + matches = _find_matches( + mt.graph, + modules, + patterns, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + subgraphs_dedup: dict[str, list[Node]] = _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope + ) + + # Now, mutate the graph to be the add_loggers graph with propagation + # error. + create_add_loggers_graph(mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig) + + return mt + + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _n_shadows_compare_weights( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> NSResultsType: + """ + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + """ + qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping( + [qconfig_mapping] + ) + mp = prepare_n_shadows_model( + model, example_inputs, qconfig_multi_mapping, backend_config + ) + # passing inputs through the model is necessary to populate + # observers which observe weights with real values + mp(*example_inputs) + mq = convert_n_shadows_model(mp) + weight_comparison = extract_weight_comparison(mq) + return weight_comparison + + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None: + """ + Sets the `enabled` setting on a `model`'s loggers + """ + for _, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.enabled = enabled + + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_save_activations( + model: torch.nn.Module, + save_activations: bool, +) -> None: + """ + Sets the `save_activations` setting on a `model`'s loggers + """ + for _name, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.save_activations = save_activations + + +def convert_n_shadows_model( + model: GraphModule, + custom_convert_fn: Callable | None = None, + custom_convert_kwargs: dict[str, Any] | None = None, +) -> GraphModule: + """ + Given a model from `prepare_n_shadows_model`, runs `convert_fx` + on each shadow submodule. + """ + for node in model.graph.nodes: + # TODO(future PR): consider matching in a safer way than + # node name string match + if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX): + orig_mod = getattr(model, node.name) + if custom_convert_fn is None: + converted_mod = torch.ao.quantization.quantize_fx.convert_fx(orig_mod) + else: + if custom_convert_kwargs is None: + custom_convert_kwargs = {} + converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs) + setattr(model, node.name, converted_mod) + + return model + + +def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType: + """ + Extracts logger results from `model`. + """ + results: NSResultsType = {} + _extract_logger_info_one_model(model, results, OutputLogger) + return results + + +def print_comparisons_n_shadows_model(results: NSResultsType) -> None: + """ + Prints a summary of extracted `results`. + """ + results_grouped = group_results_by_subgraph(results) + results_comparison = create_results_comparison(results_grouped) + print_n_shadows_summary(results_comparison) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52fc301befd34642d51f1c27e07600a1f3ef26ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__init__.py @@ -0,0 +1,23 @@ +# Variables +from ._mappings import ( + get_dynamic_sparse_quantized_mapping, + get_static_sparse_quantized_mapping, +) + +# Scheduler +from .scheduler.base_scheduler import BaseScheduler +from .scheduler.cubic_scheduler import CubicSL +from .scheduler.lambda_scheduler import LambdaSL + +# Sparsifier +from .sparsifier.base_sparsifier import BaseSparsifier +from .sparsifier.nearly_diagonal_sparsifier import NearlyDiagonalSparsifier + +# Parametrizations +from .sparsifier.utils import ( + FakeSparsity, + fqn_to_module, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) +from .sparsifier.weight_norm_sparsifier import WeightNormSparsifier diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75329434e520197aaad9a751184946f203729490 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8213797b6b8d44c2b7461581c2f723e68c5fd9d4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..1a89de12bd9345a05acee98309f90d38d70daac1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier + + +__all__ = ["FPGMPruner"] + + +class FPGMPruner(BaseStructuredSparsifier): + r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner + This sparsifier prune filter (row) in a tensor according to distances among filters according to + `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of filters (rows) that are zeroed-out. + 2. `dist` defines the distance measurement type. Default: 3 (L2 distance). + Available options are: [1, 2, (custom callable distance function)]. + + Note:: + Inputs should be a 4D convolutional tensor of shape (N, C, H, W). + - N: output channels size + - C: input channels size + - H: height of kernel + - W: width of kernel + """ + + def __init__(self, sparsity_level: float = 0.5, dist: Callable | int | None = None): + defaults = { + "sparsity_level": sparsity_level, + } + + if dist is None: + dist = 2 + + if callable(dist): + self.dist_fn = dist + elif dist == 1: + self.dist_fn = lambda x: torch.cdist(x, x, p=1) + elif dist == 2: + self.dist_fn = lambda x: torch.cdist(x, x, p=2) + else: + raise NotImplementedError("Distance function is not yet implemented.") + super().__init__(defaults=defaults) + + def _compute_distance(self, t): + r"""Compute distance across all entries in tensor `t` along all dimension + except for the one identified by dim. + Args: + t (torch.Tensor): tensor representing the parameter to prune + Returns: + distance (torch.Tensor): distance computed across filtters + """ + dim = 0 # prune filter (row) + + size = t.size(dim) + slc = [slice(None)] * t.dim() + + # flatten the tensor along the dimension + t_flatten = [ + t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1) + for i in range(size) + ] + t_flatten = torch.stack(t_flatten) + + # distance measurement + dist_matrix = self.dist_fn(t_flatten) + + # more similar with other filter indicates large in the sum of row + # pyrefly: ignore [bad-argument-type] + distance = torch.sum(torch.abs(dist_matrix), 1) + + return distance + + def update_mask( # type: ignore[override] + self, module, tensor_name, sparsity_level, **kwargs + ): + tensor_weight = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + if sparsity_level <= 0: + mask.data = torch.ones_like(mask).bool() + elif sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask).bool() + else: + distance = self._compute_distance(tensor_weight) + + tensor_size = tensor_weight.shape[0] # prune filter (row) + nparams_toprune = round(sparsity_level * tensor_size) + nparams_toprune = min( + max(nparams_toprune, 0), tensor_size + ) # clamp to [0, tensor_size] + topk = torch.topk(distance, k=nparams_toprune, largest=False) + mask[topk.indices] = False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_mappings.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc2c4f10aef5585072f36116282a2048965197a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/_mappings.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +__all__ = [ + "get_static_sparse_quantized_mapping", + "get_dynamic_sparse_quantized_mapping", +] + + +def get_static_sparse_quantized_mapping(): + import torch.ao.nn.sparse + + _static_sparse_quantized_mapping = { + torch.nn.Linear: torch.ao.nn.sparse.quantized.Linear, + } + return _static_sparse_quantized_mapping + + +def get_dynamic_sparse_quantized_mapping(): + import torch.ao.nn.sparse + + _dynamic_sparse_quantized_mapping = { + torch.nn.Linear: torch.ao.nn.sparse.quantized.dynamic.Linear, + } + return _dynamic_sparse_quantized_mapping diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8916713dae6fe008b75e6dca9d63851560ab6e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py @@ -0,0 +1,173 @@ +# mypy: allow-untyped-defs + +import warnings +import weakref +from functools import wraps + +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier + + +__all__ = ["BaseScheduler"] + + +class BaseScheduler: + def __init__(self, sparsifier, last_epoch=-1, verbose=False): + # Attach sparsifier + if not isinstance(sparsifier, BaseSparsifier): + raise TypeError( + f"{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier" + ) + self.sparsifier = sparsifier + + # Initialize epoch and base sparsity levels + + self.base_sl = [group["sparsity_level"] for group in sparsifier.groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment] + self.sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sl_called_within_step: bool = False + + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + """ + return { + key: value for key, value in self.__dict__.items() if key != "sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_sl(self): + """Return last computed sparsity level by current scheduler.""" + return self._last_sl + + def get_sl(self): + # Compute sparsity level using chainable form of the scheduler + # Note: This method is not intended to be called directly, and is only + # used by the ".step" method. Use .get_last_sl() instead. + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.", + stacklevel=2, + ) + raise NotImplementedError + + def print_sl(self, is_verbose, group, sl, epoch=None): + """Display the current sparsity level.""" + if is_verbose: + if epoch is None: + print(f"Adjusting sparsity level of group {group} to {sl:.4e}.") + else: + print( + f"Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}." + ) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Sparsifier {self.sparsifier}\n" + format_string += f" base_sl: {self.base_sl}\n" + format_string += ")" + return format_string + + def step(self, epoch=None): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + stacklevel=2, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `sparsifier.step()`. " + "You have to make sure you run the sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + stacklevel=2, + ) + self._step_count += 1 + + class _enable_get_sl_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sl_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sl_called_within_step = False + + with _enable_get_sl_call(self): + self.last_epoch += 1 + values = self.get_sl() + + for i, data in enumerate(zip(self.sparsifier.groups, values)): + param_group, sl = data + param_group["sparsity_level"] = sl + self.print_sl(self.verbose, i, sl, epoch) + + self._last_sl = [group["sparsity_level"] for group in self.sparsifier.groups] + self.sparsifier.enable_mask_update = True + + def _make_sure_a_list(self, var): + r"""Utility that extends it to the same length as the .groups, ensuring it is a list""" + n = len(self.sparsifier.groups) + if not isinstance(var, (list, tuple)): + return [var] * n + else: + if len(var) != n: + raise ValueError(f"Expected variable of length {n}, but got {len(var)}") + return list(var) # We want the result to be in a list, not tuple diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..d4706900762adf411eb68dfd7fee3ff9fed36b51 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +import warnings + +from .base_scheduler import BaseScheduler + + +__all__ = ["CubicSL"] + + +def _clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +class CubicSL(BaseScheduler): + r"""Sets the sparsity level of each parameter group to the final sl + plus a given exponential function. + + .. math:: + + s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3 + + where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final + sparsity level, :math:`f(i)` is the function to be applied to the current epoch + :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`. + :math:`\Delta t` is used to control how often the update of the sparsity level + happens. By default, + + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + init_sl (int, list): Initial level of sparsity + init_t (int, list): Initial step, when pruning starts + delta_t (int, list): Pruning frequency + total_t (int, list): Total number of pruning steps + initially_zero (bool, list): If True, sets the level of sparsity to 0 + before init_t (:math:`t_0`). Otherwise, the sparsity level before + init_t (:math:`t_0`) is set to init_sl(:math:`s_0`) + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__( + self, + sparsifier, + init_sl=0.0, + init_t=0, + delta_t=10, + total_t=100, + initially_zero=False, + last_epoch=-1, + verbose=False, + ): + self.sparsifier = sparsifier + + self.init_sl = self._make_sure_a_list(init_sl) + self.init_t = self._make_sure_a_list(init_t) + self.delta_t = self._make_sure_a_list(delta_t) + self.total_t = self._make_sure_a_list(total_t) + + self.initially_zero = self._make_sure_a_list(initially_zero) + + super().__init__(sparsifier, last_epoch, verbose) + + @staticmethod + def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False): + r""" "Computes the current level of sparsity. + + Based on https://arxiv.org/pdf/1710.01878.pdf + + Args: + s_0: Initial level of sparsity, :math:`s_i` + s_f: Target level of sparsity, :math:`s_f` + t: Current step, :math:`t` + t_0: Initial step, :math:`t_0` + dt: Pruning frequency, :math:`\Delta T` + n: Pruning steps, :math:`n` + initially_zero: Sets the level of sparsity to 0 before t_0. + If False, sets to s_0 + + Returns: + The sparsity level :math:`s_t` at the current step :math:`t` + """ + if initially_zero and t < t_0: + return 0 + s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3 + s_t = _clamp(s_t, s_0, s_f) + return s_t + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.", + stacklevel=2, + ) + return [ + self.sparsity_compute_fn( + s_0=initial_sparsity, + s_f=final_sparsity, + t=self.last_epoch, + t_0=initial_epoch, + dt=delta_epoch, + n=interval_epochs, + initially_zero=initially_zero, + ) + for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in zip( + self.init_sl, + self.base_sl, + self.init_t, + self.delta_t, + self.total_t, + self.initially_zero, + ) + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5737095bf6662ba13a22a8ee8287d07263c05f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -0,0 +1,64 @@ +import warnings +from collections.abc import Callable + +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier + +from .base_scheduler import BaseScheduler + + +__all__ = ["LambdaSL"] + + +class LambdaSL(BaseScheduler): + """Sets the sparsity level of each parameter group to the final sl + times a given function. When last_epoch=-1, sets initial sl as zero. + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + sl_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in sparsifier.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + Example: + >>> # Assuming sparsifier has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95**epoch + >>> # xdoctest: +SKIP + >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + sparsifier: BaseSparsifier, + sl_lambda: Callable[[int], float] | list[Callable[[int], float]], + last_epoch: int = -1, + verbose: bool = False, + ) -> None: + self.sparsifier = sparsifier + + if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): + self.sl_lambdas = [sl_lambda] * len(sparsifier.groups) + else: + if len(sl_lambda) != len(sparsifier.groups): + raise ValueError( + f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}" + ) + self.sl_lambdas = list(sl_lambda) + super().__init__(sparsifier, last_epoch, verbose) # type: ignore[no-untyped-call] + + def get_sl(self) -> list[float]: + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.", + stacklevel=2, + ) + return [ + base_sl * lmbda(self.last_epoch) + for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl) + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..1f55d63a26781a3875a5d3ee36fb0ee906a5a0d9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -0,0 +1,359 @@ +# mypy: allow-untyped-defs +import abc +import copy +from collections import defaultdict +from typing import Any + +import torch +from torch import nn +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + FakeSparsity, + get_arg_info_from_tensor_fqn, + module_contains_param, + module_to_fqn, + swap_module, +) + + +__all__ = ["BaseSparsifier"] + +SUPPORTED_MODULES = {nn.Linear} + +KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"] + + +# TODO update desc with new config args +class BaseSparsifier(abc.ABC): + r"""Base class for all sparsifiers. + + Abstract methods that need to be implemented: + + - update_mask: Function to compute a new mask for all keys in the + `groups`. + + Args: + - model [nn.Module]: model to configure. The model itself is not saved + but used for the state_dict saving / loading. + - config [list]: configuration elements should be a dict map that includes + `tensor_fqn` of tensors to sparsify + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + + Example:: + + >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask") + >>> config = [{'tensor_fqn': 'layer1.weight', 'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}] + >>> defaults = {'sparsity_level': 0.7} + >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default) + >>> sparsifier = BaseSparsifier(config, defaults) + """ + + def __init__(self, defaults: dict[str, Any] | None = None): + super().__init__() + self.defaults: dict[str, Any] = defaults or {} + + self.state: dict[str, dict] = defaultdict(dict) + self.groups: list[dict[str, Any]] = [] + self.enable_mask_update = True + + def __getstate__(self) -> dict[str, Any]: + return { + "defaults": self.defaults, + "state": self.state, + "groups": self.groups, + } + + def __setstate__(self, state: dict[str, dict[str, Any]]) -> None: + self.__dict__.update(state) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for i, sparse_args in enumerate(self.groups): + module = sparse_args["module"] + format_string += "\n" + format_string += f"\tGroup {i}\n" + format_string += f"\t module: {module}\n" + for key in sorted(sparse_args.keys()): + if key == "module": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def state_dict(self) -> dict[str, Any]: + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - current state of the sparsification. + * groups - a list containing all sparsity configuration groups + with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model + + TODO: Need a clean way of loading the state of the "prepared" module + """ + + groups: list[dict[str, Any]] = [ + dict( + filter( + lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT, + mg.items(), + ) + ) + for mg in self.groups + ] + + return { + "state": self.state, + "groups": groups, + } + + def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True): + groups = copy.deepcopy(state_dict["groups"]) + states = state_dict["state"] + for tensor_fqn, s in states.items(): + arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn) + module = arg_info["module"] + tensor_name = arg_info["tensor_name"] + if strict and module is None: + raise RuntimeError(f"Error loading {tensor_fqn} into the model") + + found = False + for p in module.parametrizations[tensor_name]: + if isinstance(p, FakeSparsity): + found = True + break + if not found: + p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape)) + parametrize.register_parametrization(module, tensor_name, p) + if s.get("mask", None) is not None: + mask = s.pop("mask") + p.mask = mask + + for mg in groups: + if mg["tensor_fqn"] == tensor_fqn: + mg.update(arg_info) + self.__setstate__({"state": states, "groups": groups}) + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: set[type[nn.Linear]] = SUPPORTED_MODULES, + ) -> None: + self.config = [] + stack = [model] + while stack: + module = stack.pop() + for _name, child in module.named_children(): + if type(child) in SUPPORTED_MODULES: + module_fqn = module_to_fqn(model, child) + if not isinstance(module_fqn, str): + raise AssertionError("module_fqn must be a string") + self.config.append({"tensor_fqn": module_fqn + ".weight"}) + else: + stack.append(child) + + def prepare(self, model, config): + r"""Prepares a model, by adding the parametrizations. + + Note:: + + The model is modified inplace. If you need to preserve the original + model, use copy.deepcopy. + """ + self.model = model # TODO: Need to figure out how to load without this. + self.config = config + + # If no config -- try getting all the supported layers + if self.config is None: + self.make_config_from_model(model) + + # TODO: Remove the configuration by reference ('module') + # pyrefly: ignore [not-iterable] + for module_config in self.config: + if not isinstance(module_config, dict): + raise AssertionError( + "config elements should be dicts not modules i.e.:" + "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" + ) + + if not isinstance(self.defaults, dict): + raise AssertionError("defaults must be a dict") + local_args = copy.deepcopy(self.defaults) + local_args.update(module_config) + + tensor_fqn = local_args.get("tensor_fqn", None) + if tensor_fqn is None: + raise AssertionError( + "tensor_fqn is a required argument in the sparsity config which" + "replaces previous `module` and [module]`fqn` arguments" + ) + + # populate all information from tensor_fqn + info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) + + # check that whatever was put into local_args agrees with what was obtained + # from tensor_fqn + for key in info_from_tensor_fqn: + if key in local_args: + if not ( + info_from_tensor_fqn[key] == local_args[key] + or ( + key == "tensor_fqn" + and "." + info_from_tensor_fqn[key] == local_args[key] + ) + # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that + ): + raise AssertionError( + f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" + ) + local_args.update(info_from_tensor_fqn) + self.groups.append(local_args) + self._prepare() + + def _prepare(self, *args, **kwargs): + r"""Adds mask parametrization to the layer weight""" + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeSparsity) + mask = config.get("mask", torch.ones_like(getattr(module, tensor_name))) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + def squash_mask( + self, + params_to_keep: tuple[str, ...] | None = None, + params_to_keep_per_layer: dict[str, tuple[str, ...]] | None = None, + *args, + **kwargs, + ): + r"""Squashes the sparse masks into the appropriate tensors. + + If either the `params_to_keep` or `params_to_keep_per_layer` is set, + the module will have a `sparse_params` dict attached to it. + + Args: + params_to_keep: List of keys to save in the module or a dict + representing the modules and keys that will have + sparsity parameters saved + params_to_keep_per_layer: Dict to specify the params that should be + saved for specific layers. The keys in the dict + should be the module fqn, while the values should + be a list of strings with the names of the variables + to save in the `sparse_params` + + Examples: + >>> # xdoctest: +SKIP("locals are undefined") + >>> # Don't save any sparse params + >>> sparsifier.squash_mask() + >>> hasattr(model.submodule1, "sparse_params") + False + + >>> # Keep sparse params per layer + >>> sparsifier.squash_mask( + ... params_to_keep_per_layer={ + ... "submodule1.linear1": ("foo", "bar"), + ... "submodule2.linear42": ("baz",), + ... } + ... ) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'baz': 0.1} + + >>> # Keep sparse params for all layers + >>> sparsifier.squash_mask(params_to_keep=("foo", "bar")) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24} + + >>> # Keep some sparse params for all layers, and specific ones for + >>> # some other layers + >>> sparsifier.squash_mask( + ... params_to_keep=("foo", "bar"), + ... params_to_keep_per_layer={"submodule2.linear42": ("baz",)}, + ... ) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24, 'baz': 0.1} + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrize.remove_parametrizations( + module, tensor_name, leave_parametrized=True + ) + sparse_params = {} + if params_to_keep is not None: + global_params = {k: config[k] for k in params_to_keep} + sparse_params.update(global_params) + if params_to_keep_per_layer is not None: + params = params_to_keep_per_layer.get(config["module_fqn"], None) + if params is not None: + per_layer_params = {k: config[k] for k in params} + sparse_params.update(per_layer_params) + if sparse_params: + # TODO handle multiple tensor being quantized on a single module, where to store sparse_params? + module.sparse_params = sparse_params + + def convert( + self, + module: nn.Module, + mapping: dict[type[nn.Module], type[nn.Module]] | None = None, + inplace: bool = False, + parameterization: type[nn.Module] = FakeSparsity, + ): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_dense` method on the target module class + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + """ + if mapping is None: + raise NotImplementedError("Need to auto generate mapping ") + if not inplace: + module = copy.deepcopy(module) + + reassign = {} + for name, mod in module.named_children(): + # leaf node + if ( + module_contains_param(mod, parameterization) + and type_before_parametrizations(mod) in mapping + ): + reassign[name] = swap_module(mod, mapping) + else: + # recurse + reassign[name] = self.convert( + mod, + mapping=mapping, + inplace=True, + parameterization=parameterization, + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + def step(self, use_path: bool = True) -> None: + if not self.enable_mask_update: + return + with torch.no_grad(): + for config in self.groups: + self.update_mask(**config) + + @abc.abstractmethod + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs): + pass diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..26fb3a98b8fb7d37e6bd5965d1d41b091d3e4818 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +import torch + +from . import base_sparsifier + + +class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier): + r"""Nearly Diagonal Sparsifier + + This sparsifier creates a nearly diagonal mask to be applied to the weight matrix. + Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero. + An example of a nearly diagonal matrix with degree (or nearliness) 3 and 5 are follows respectively. + 1 1 0 0 1 1 1 0 + 1 1 1 0 1 1 1 1 + 0 1 1 1 1 1 1 1 + 0 0 1 1 0 1 1 1 + Note that a nearly diagonal matrix with degree 1 is just a matrix with main diagonal populated + + This sparsifier is controlled by one variable: + 1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal. + Currently - supports only odd number + + Note: + This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix + feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy + + Args: + nearliness: The degree of nearliness (default = 1) + + """ + + def __init__(self, nearliness: int = 1): + defaults = {"nearliness": nearliness} + super().__init__(defaults=defaults) + + def update_mask( # type:ignore[override] + self, module, tensor_name, nearliness, **kwargs + ): + mask = getattr(module.parametrizations, tensor_name)[0].mask + mask.data = torch.zeros_like(mask) + if nearliness <= 0: + return + + tensor = getattr(module, tensor_name) + height, width = tensor.shape + + if nearliness % 2 == 0: + raise ValueError("nearliness can only be an odd number") + dist_to_diagonal = nearliness // 2 + # check + if dist_to_diagonal >= min(height, width): + raise ValueError( + "nearliness cannot be larger than the dimensions of tensor." + ) + + for row in range(height): + # Bounds of entries that needs to be set to 1 + low = max(0, row - dist_to_diagonal) + high = min(width, row + dist_to_diagonal + 1) + mask[row, low:high].fill_(1) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..97461630bc3ae9ce60cd02ce13a2371d9ba05536 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +from itertools import chain +from typing import Any + +from torch import nn +from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations + + +__all__ = [ + "module_contains_param", + "swap_module", + "module_to_fqn", + "fqn_to_module", + "get_arg_info_from_tensor_fqn", + "FakeSparsity", +] + + +def module_contains_param(module: nn.Module, parametrization: type[nn.Module]) -> bool: + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator] + ) + return False + + +def swap_module( + mod: nn.Module, mapping: dict[type[nn.Module], type[nn.Module]] +) -> nn.Module: + r"""Swaps the module using from_dense according to the mapping passed in. + Args: + mod: input module + mapping: a dictionary that maps from nn module to sparse nn module + Return: + The corresponding sparse module of `mod` according to mapping, created using from_dense + """ + if type_before_parametrizations(mod) in mapping: + sparse_mod = mapping[type_before_parametrizations(mod)] + + # TODO Fix this typing, as Type[Module] has no attribute "from_dense" + new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined] + + # Preserve module's pre forward hooks. They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + # pyrefly: ignore [bad-argument-type] + devices = {p.device for p in chain(mod.parameters(), mod.buffers())} + if len(devices) > 1: + raise AssertionError( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + + return new_mod + + else: + return mod + + +def module_to_fqn(model: nn.Module, module: nn.Module, prefix: str = "") -> str | None: + """ + Returns the fqn for a module or None if module not a descendent of model. + """ + if module is model: + return "" + for name, child in model.named_children(): + fqn = module_to_fqn(child, module, ".") + if isinstance(fqn, str): + return prefix + name + fqn + return None + + +def fqn_to_module(model: nn.Module | None, path: str) -> nn.Module | None: + """ + Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path` + doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors. + """ + if path != "": + for name in path.split("."): + model = getattr(model, name, None) + return model + + +def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, Any]: + """ + Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name + """ + # string manip to split tensor_fqn into module_fqn and tensor_name + # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' + # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' + tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1] + module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] + + module = fqn_to_module(model, module_fqn) + + return { + "module_fqn": module_fqn, + "module": module, + "tensor_name": tensor_name, + "tensor_fqn": tensor_fqn, + } + + +# Parametrizations +class FakeSparsity(nn.Module): + r"""Parametrization for the weights. Should be attached to the 'weight' or + any other parameter that requires a mask applied to it. + + Note:: + + Once the mask is passed, the variable should not change the id. The + contents of the mask can change, but the mask reference itself should + not. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + if self.mask.shape != x.shape: + raise AssertionError( + f"mask shape ({self.mask.shape}) must match x shape ({x.shape})" + ) + return self.mask * x + + def state_dict(self, *args, **kwargs): + # We don't want to let the parametrizations to save the mask. + # That way we make sure that the linear module doesn't store the masks + # alongside their parametrizations. + return {} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd0368f156744f1af362670fa73baf505f50251 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +import operator +from collections.abc import Callable +from functools import reduce + +import torch +import torch.nn.functional as F + +from .base_sparsifier import BaseSparsifier + + +__all__ = ["WeightNormSparsifier"] + + +def _flat_idx_to_2d(idx, shape): + rows = idx // shape[1] + cols = idx % shape[1] + return rows, cols + + +class WeightNormSparsifier(BaseSparsifier): + r"""Weight-Norm Sparsifier + + This sparsifier computes the norm of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + + Args: + + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block (see note below) + zeros_per_block: Number of zeros in a sparse block + norm: Norm to use. Could be either `int` or a callable. + If `int`, only L1 and L2 are implemented. + + Note:: + The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS), + irrespective of what the rows / cols mean in the data tensor. That means, + if you were to sparsify a weight tensor in the nn.Linear, which has a + weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output + channels, while the `block_COLS` would refer to the input channels. + + Note:: + All arguments to the WeightNormSparsifier constructor are "default" + arguments and could be overridden by the configuration provided in the + `prepare` step. + """ + + def __init__( + self, + sparsity_level: float = 0.5, + sparse_block_shape: tuple[int, int] = (1, 4), + zeros_per_block: int | None = None, + norm: Callable | int | None = None, + ): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + if norm is None: + norm = 2 + if callable(norm): + self.norm_fn = norm + elif norm == 1: + self.norm_fn = lambda T: T.abs() + elif norm == 2: + self.norm_fn = lambda T: T * T + else: + raise NotImplementedError(f"L-{norm} is not yet implemented.") + super().__init__(defaults=defaults) + + def _scatter_fold_block_mask( + self, + output_shape, + dim, + indices, + block_shape, + mask=None, + input_shape=None, + device=None, + ): + r"""Creates patches of size `block_shape` after scattering the indices.""" + if mask is None: + if input_shape is None: + raise AssertionError("input_shape must be provided when mask is None") + mask = torch.ones(input_shape, device=device) + mask.scatter_(dim=dim, index=indices, value=0) + mask.data = F.fold( + mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape + ) + return mask + + def _make_tensor_mask( + self, data, input_shape, sparsity_level, sparse_block_shape, mask=None + ): + r"""Creates a tensor-level mask. + + Tensor-level mask is described as a mask, where the granularity of sparsification of the + smallest patch is the sparse_block_shape. That means, that for a given mask and a + sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape. + + In this context, `sparsity_level` describes the fraction of sparse patches. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + + if mask is None: + mask = torch.ones(h + dh, w + dw, device=data.device) + + if sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask) + return mask + elif sparsity_level <= 0.0: + mask.data = torch.ones_like(mask) + return mask + + values_per_block = reduce(operator.mul, sparse_block_shape) + if values_per_block > 1: + # Reduce the data + data = F.avg_pool2d( + data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ceil_mode=True, + ) + data = data.flatten() + num_blocks = len(data) + + data = data.repeat(1, values_per_block, 1) + + threshold_idx = round(sparsity_level * num_blocks) + threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check + _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False) + + # Temp reshape for mask + mask_reshape = mask.reshape(data.shape) # data might be reshaped + self._scatter_fold_block_mask( + dim=2, + output_shape=(h + dh, w + dw), + indices=sorted_idx, + block_shape=sparse_block_shape, + mask=mask_reshape, + ) + mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + return mask + + def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None): + r"""Creates a block-level mask. + + Block-level mask is described as a mask, where the granularity of sparsification of the + largest patch is the sparse_block_shape. That means that for a given mask and a + sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape. + + In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + values_per_block = reduce(operator.mul, sparse_block_shape) + + if mask is None: + mask = torch.ones((h + dh, w + dw), device=data.device) + + if values_per_block == zeros_per_block: + # Everything should be sparsified + mask.data = torch.zeros_like(mask) + return mask + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device) + padded_data.fill_(torch.nan) + padded_data[:h, :w] = data + unfolded_data = F.unfold( + padded_data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + + # Temp reshape for mask + mask_reshape = mask.reshape(unfolded_data.shape) + _, sorted_idx = torch.topk( + unfolded_data, k=zeros_per_block, dim=1, largest=False + ) + + self._scatter_fold_block_mask( + dim=1, + indices=sorted_idx, + output_shape=padded_data.shape, + block_shape=sparse_block_shape, + mask=mask_reshape, + ) + + mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() + return mask + + def update_mask( # type: ignore[call-override, override] + self, + module, + tensor_name, + sparsity_level, + sparse_block_shape, + zeros_per_block, + **kwargs, + ): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + mask = getattr(module.parametrizations, tensor_name)[0].mask + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + else: + ww = self.norm_fn(getattr(module, tensor_name)) + tensor_mask = self._make_tensor_mask( + data=ww, + # pyrefly: ignore [missing-attribute] + input_shape=ww.shape, + sparsity_level=sparsity_level, + sparse_block_shape=sparse_block_shape, + ) + if values_per_block != zeros_per_block: + block_mask = self._make_block_mask( + data=ww, + sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block, + ) + tensor_mask = torch.logical_or(tensor_mask, block_mask) + mask.data = tensor_mask diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2efc24081b0c13d94b7ab256f635eafce8614543 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/__init__.py @@ -0,0 +1,247 @@ +# mypy: allow-untyped-defs + +import sys +from collections.abc import Callable +from typing import Optional, Union + +import torch +from torch import Tensor + +from .fake_quantize import * # noqa: F403 +from .fuse_modules import fuse_modules, fuse_modules_qat # noqa: F403 +from .fuser_method_mappings import * # noqa: F403 +from .observer import * # noqa: F403 +from .pt2e._numeric_debugger import ( # noqa: F401 + compare_results, + CUSTOM_KEY, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + prepare_for_propagation_comparison, +) +from .pt2e.export_utils import ( + _allow_exported_model_train_eval as allow_exported_model_train_eval, + _move_exported_model_to_eval as move_exported_model_to_eval, + _move_exported_model_to_train as move_exported_model_to_train, +) + +# pyrefly: ignore [deprecated] +from .qconfig import * # noqa: F403 +from .qconfig_mapping import * # noqa: F403 +from .quant_type import * # noqa: F403 +from .quantization_mappings import * # noqa: F403 # type: ignore[no-redef] +from .quantize import * # noqa: F403 +from .quantize_jit import * # noqa: F403 +from .stubs import * # noqa: F403 + + +# ensure __module__ is set correctly for public APIs +if sys.version_info < (3, 12): + ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] + ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +else: + from typing import TypeAliasType + + ObserverOrFakeQuantize = TypeAliasType( + "ObserverOrFakeQuantize", ObserverBase | FakeQuantizeBase + ) + +for _f in [ + compare_results, + extract_results_from_loggers, + generate_numeric_debug_handle, + prepare_for_propagation_comparison, +]: + _f.__module__ = "torch.ao.quantization" + +__all__ = [ + "DeQuantStub", + "FakeQuantize", + "FakeQuantizeBase", + "FixedQParamsFakeQuantize", + "FixedQParamsObserver", + "FusedMovingAvgObsFakeQuantize", + "HistogramObserver", + "MatchAllNode", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "ObserverOrFakeQuantize", + "Pattern", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "QConfig", + "QConfigAny", + "QConfigDynamic", + "QConfigMapping", + "QuantStub", + "QuantType", + "QuantWrapper", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "add_quant_dequant", + "convert", + "convert_dynamic_jit", + "convert_jit", + "default_affine_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_fake_quant", + "default_dynamic_quant_observer", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_eval_fn", + "default_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_fused_act_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "default_fused_wt_fake_quant", + "default_histogram_fake_quant", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_fake_quant", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_fake_quant", + "default_symmetric_fixed_qparams_observer", + "default_weight_fake_quant", + "default_weight_observer", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "fuse_conv_bn", + "fuse_conv_bn_jit", + "fuse_conv_bn_relu", + "fuse_convtranspose_bn", + "fuse_linear_bn", + "fuse_modules", + "fuse_modules_qat", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", + "fused_wt_fake_quant_range_neg_127_to_127", + "get_combined_dict", + "get_default_compare_output_module_list", + "get_default_custom_config_dict", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_float_to_quantized_operator_mappings", + "get_default_qat_module_mappings", + "get_default_qat_qconfig", + "get_default_qat_qconfig_dict", + "get_default_qat_qconfig_mapping", + "get_default_qconfig", + "get_default_qconfig_dict", + "get_default_qconfig_mapping", + "get_default_qconfig_propagation_list", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_dynamic_quant_module_class", + "get_embedding_qat_module_mappings", + "get_embedding_static_quant_module_mappings", + "get_fuser_method", + "get_fuser_method_new", + "get_observer_state_dict", + "get_quantized_operator", + "get_static_quant_module_class", + "load_observer_state_dict", + "move_exported_model_to_eval", + "move_exported_model_to_train", + "allow_exported_model_train_eval", + "no_observer_set", + "per_channel_weight_observer_range_neg_127_to_127", + "prepare", + "prepare_dynamic_jit", + "prepare_jit", + "prepare_qat", + "propagate_qconfig_", + "qconfig_equals", + "quantize", + "quantize_dynamic", + "quantize_dynamic_jit", + "quantize_jit", + "quantize_qat", + "script_qconfig", + "script_qconfig_dict", + "swap_module", + "weight_observer_range_neg_127_to_127", + "generate_numeric_debug_handle", + "CUSTOM_KEY", + "NUMERIC_DEBUG_HANDLE_KEY", + "prepare_for_propagation_comparison", + "extract_results_from_loggers", + "compare_results", + # from torchao, should be merged with torchao + # in the future + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +def default_eval_fn(model, calib_data): + r"""Define the default evaluation function. + + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, _target in calib_data: + model(data) + + +class _DerivedObserverOrFakeQuantize(ObserverBase): + r"""This observer is used to describe an observer whose quantization parameters + are derived from other observers + """ + + def __init__( + self, + dtype: torch.dtype, + obs_or_fqs: list[ObserverOrFakeQuantize], + derive_qparams_fn: Callable[ + [list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor] + ], + quant_min: int | None = None, + quant_max: int | None = None, + qscheme: torch.qscheme | None = None, + ch_axis: int | None = None, + ): + super().__init__(dtype) + self.obs_or_fqs = obs_or_fqs + self.derive_qparams_fn = derive_qparams_fn + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + self.ch_axis = ch_axis + + from .utils import is_per_channel + + if is_per_channel(self.qscheme): + if self.ch_axis is None: + raise AssertionError( + "Must provide a valid ch_axis if qscheme is per channel" + ) + + def forward(self, x: Tensor) -> Tensor: + return x + + def calculate_qparams(self): # type:ignore[override] + return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_correct_bias.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_correct_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..4309e4530cb72bd6620a69527cbe87e2a533c323 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_correct_bias.py @@ -0,0 +1,156 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.quantized as nnq +import torch.ao.ns._numeric_suite as ns +import torch.ao.quantization +import torch.nn as nn + + +__all__ = [ + "get_module", + "parent_child_names", + "get_param", + "MeanShadowLogger", + "bias_correction", +] + +_supported_modules = {nn.Linear, nn.Conv2d} +_supported_modules_quantized = {nnq.Linear, nnq.Conv2d} + + +def get_module(model, name): + """Given name of submodule, this function grabs the submodule from given model.""" + return dict(model.named_modules())[name] + + +def parent_child_names(name): + """Split full name of submodule into parent submodule's full name and submodule's name.""" + split_name = name.rsplit(".", 1) + if len(split_name) == 1: + return "", split_name[0] + else: + return split_name[0], split_name[1] + + +def get_param(module, attr): + """Get the parameter given a module and attribute. + + Sometimes the weights/bias attribute gives you the raw tensor, but sometimes + gives a function that will give you the raw tensor, this function takes care of that logic + """ + param = getattr(module, attr, None) + if callable(param): + return param() + else: + return param + + +class MeanShadowLogger(ns.Logger): + """Mean Logger for a Shadow module. + + A logger for a Shadow module whose purpose is to record the rolling mean + of the data passed to the floating point and quantized models + """ + + def __init__(self): + """Set up initial values for float and quantized stats, count, float sum, and quant sum.""" + super().__init__() + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + + def forward(self, x, y): # type: ignore[override] + """Compute the average of quantized and floating-point data from modules. + + The inputs x,y are output data from the quantized and floating-point modules. + x is for the quantized module, y is for the floating point module + """ + if x.is_quantized: + x = x.dequantize() + + self.count += 1 + if self.stats["quantized"] is None: + self.stats["quantized"] = x + self.quant_sum = x + else: + self.quant_sum += x + self.stats["quantized"] = self.quant_sum / self.count + + if self.stats["float"] is None: + self.stats["float"] = y + self.float_sum = y + else: + self.float_sum += y + self.stats["float"] = self.float_sum / self.count + + def clear(self): + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + + +def bias_correction( + float_model, + quantized_model, + img_data, + target_modules=_supported_modules_quantized, + neval_batches=None, +): + """Perform bias correction on a module. + + Using numeric suite shadow module, the expected output of the floating point and quantized modules + is recorded. Using that data the bias of supported modules is shifted to compensate for the drift caused + by quantization + Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2) + + Args: + float_model: a trained model that serves as a reference to what bias correction should aim for + quantized_model: quantized form of float_model that bias correction is to applied to + img_data: calibration data to estimate the expected output (used to find quantization error) + target_modules: specifies what submodules in quantized_model need bias correction (can be extended to + unquantized submodules) + neval_batches: a cap to the number of batches you want to be used for estimating the expected output + """ + ns.prepare_model_with_stubs( + float_model, quantized_model, _supported_modules, MeanShadowLogger + ) + + uncorrected_modules = { + name: submodule + for name, submodule in quantized_model.named_modules() + if type(submodule) in target_modules + } + + for uncorrected_module in uncorrected_modules: + quantized_submodule = get_module(quantized_model, uncorrected_module) + bias = get_param(quantized_submodule, "bias") + if bias is not None: + for count, data in enumerate(img_data, start=1): + quantized_model(data[0]) + if count == neval_batches: + break + ob_dict = ns.get_logger_dict(quantized_model) + parent_name, _ = parent_child_names(uncorrected_module) + + float_data = ob_dict[parent_name + ".stats"]["float"] + quant_data = ob_dict[parent_name + ".stats"]["quantized"] + + # math for expected_error + quantization_error = quant_data - float_data + dims = list(range(quantization_error.dim())) + # Note: we don't want to take the mean over the output channel dimension + dims.remove(1) + expected_error = torch.mean(quantization_error, dims) + + updated_bias = bias.data - expected_error + + bias.data = updated_bias + + # Resets the data contained in the loggers + for submodule in quantized_model.modules(): + if isinstance(submodule, MeanShadowLogger): + submodule.clear() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_equalize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ff327f285aa4c17f05a9cbf61b7323a0536a12 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_equalize.py @@ -0,0 +1,279 @@ +# mypy: allow-untyped-defs +import copy +from itertools import chain +from typing import Any + +import torch + + +__all__ = [ + "set_module_weight", + "set_module_bias", + "has_bias", + "get_module_weight", + "get_module_bias", + "max_over_ndim", + "min_over_ndim", + "channel_range", + "get_name_by_module", + "cross_layer_equalization", + "process_paired_modules_list_to_name", + "expand_groups_in_paired_modules_list", + "equalize", + "converged", +] + +_supported_types = {torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d} +_supported_intrinsic_types = { + torch.ao.nn.intrinsic.ConvReLU2d, + torch.ao.nn.intrinsic.LinearReLU, + torch.ao.nn.intrinsic.ConvReLU1d, +} +_all_supported_types = _supported_types.union(_supported_intrinsic_types) + + +def set_module_weight(module, weight) -> None: + if type(module) in _supported_types: + module.weight = torch.nn.Parameter(weight) + else: + module[0].weight = torch.nn.Parameter(weight) + + +def set_module_bias(module, bias) -> None: + if type(module) in _supported_types: + module.bias = torch.nn.Parameter(bias) + else: + module[0].bias = torch.nn.Parameter(bias) + + +def has_bias(module) -> bool: + if type(module) in _supported_types: + return module.bias is not None + else: + return module[0].bias is not None + + +def get_module_weight(module): + if type(module) in _supported_types: + return module.weight + else: + return module[0].weight + + +def get_module_bias(module): + if type(module) in _supported_types: + return module.bias + else: + return module[0].bias + + +def max_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.max' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.max(axis, keepdim) + return input + + +def min_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.min' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.min(axis, keepdim) + return input + + +def channel_range(input, axis=0): + """Find the range of weights associated with a specific channel.""" + size_of_tensor_dim = input.ndim + axis_list = list(range(size_of_tensor_dim)) + axis_list.remove(axis) + + mins = min_over_ndim(input, axis_list) + maxs = max_over_ndim(input, axis_list) + + if mins.size(0) != input.size(axis): + raise AssertionError( + "Dimensions of resultant channel range does not match size of requested axis" + ) + return maxs - mins + + +def get_name_by_module(model, module): + """Get the name of a module within a model. + + Args: + model: a model (nn.module) that equalization is to be applied on + module: a module within the model + + Returns: + name: the name of the module within the model + """ + for name, m in model.named_modules(): + if m is module: + return name + raise ValueError("module is not in the model") + + +def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1): + """Scale the range of Tensor1.output to equal Tensor2.input. + + Given two adjacent tensors', the weights are scaled such that + the ranges of the first tensors' output channel are equal to the + ranges of the second tensors' input channel + """ + if ( + type(module1) not in _all_supported_types + or type(module2) not in _all_supported_types + ): + raise ValueError( + "module type not supported:", type(module1), " ", type(module2) + ) + + bias = get_module_bias(module1) if has_bias(module1) else None + + weight1 = get_module_weight(module1) + weight2 = get_module_weight(module2) + + if weight1.size(output_axis) != weight2.size(input_axis): + raise TypeError( + "Number of output channels of first arg do not match \ + number input channels of second arg" + ) + + weight1_range = channel_range(weight1, output_axis) + weight2_range = channel_range(weight2, input_axis) + + # producing scaling factors to applied + weight2_range += 1e-9 + scaling_factors = torch.sqrt(weight1_range / weight2_range) + inverse_scaling_factors = torch.reciprocal(scaling_factors) + + if bias is not None: + bias = bias * inverse_scaling_factors + + # formatting the scaling (1D) tensors to be applied on the given argument tensors + # pads axis to (1D) tensors to then be broadcasted + size1 = [1] * weight1.ndim + size1[output_axis] = weight1.size(output_axis) + size2 = [1] * weight2.ndim + size2[input_axis] = weight2.size(input_axis) + + scaling_factors = torch.reshape(scaling_factors, size2) + inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1) + + weight1 = weight1 * inverse_scaling_factors + weight2 = weight2 * scaling_factors + + set_module_weight(module1, weight1) + if bias is not None: + set_module_bias(module1, bias) + set_module_weight(module2, weight2) + + +def process_paired_modules_list_to_name(model, paired_modules_list): + """Processes a list of paired modules to a list of names of paired modules.""" + + for group in paired_modules_list: + for i, item in enumerate(group): + if isinstance(item, torch.nn.Module): + group[i] = get_name_by_module(model, item) + elif not isinstance(item, str): + raise TypeError("item must be a nn.Module or a string") + return paired_modules_list + + +def expand_groups_in_paired_modules_list(paired_modules_list): + """Expands module pair groups larger than two into groups of two modules.""" + new_list = [] + + for group in paired_modules_list: + if len(group) == 1: + raise ValueError("Group must have at least two modules") + elif len(group) == 2: + new_list.append(group) + elif len(group) > 2: + new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1)) + + return new_list + + +def equalize(model, paired_modules_list, threshold=1e-4, inplace=True): + """Equalize modules until convergence is achieved. + + Given a list of adjacent modules within a model, equalization will + be applied between each pair, this will repeated until convergence is achieved + + Keeps a copy of the changing modules from the previous iteration, if the copies + are not that different than the current modules (determined by converged_test), + then the modules have converged enough that further equalizing is not necessary + + Reference is section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf + + Args: + model: a model (nn.Module) that equalization is to be applied on + paired_modules_list (List(List[nn.module || str])): a list of lists + where each sublist is a pair of two submodules found in the model, + for each pair the two modules have to be adjacent in the model, + with only piece-wise-linear functions like a (P)ReLU or LeakyReLU in between + to get expected results. + The list can contain either modules, or names of modules in the model. + If you pass multiple modules in the same list, they will all be equalized together. + threshold (float): a number used by the converged function to determine what degree + of similarity between models is necessary for them to be called equivalent + inplace (bool): determines if function is inplace or not + """ + + paired_modules_list = process_paired_modules_list_to_name( + model, paired_modules_list + ) + + if not inplace: + model = copy.deepcopy(model) + + paired_modules_list = expand_groups_in_paired_modules_list(paired_modules_list) + + name_to_module: dict[str, torch.nn.Module] = {} + previous_name_to_module: dict[str, Any] = {} + name_set = set(chain.from_iterable(paired_modules_list)) + + for name, module in model.named_modules(): + if name in name_set: + name_to_module[name] = module + previous_name_to_module[name] = None + while not converged(name_to_module, previous_name_to_module, threshold): + for pair in paired_modules_list: + previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]]) + previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]]) + + cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]]) + + return model + + +def converged(curr_modules, prev_modules, threshold=1e-4): + """Test whether modules are converged to a specified threshold. + + Tests for the summed norm of the differences between each set of modules + being less than the given threshold + + Takes two dictionaries mapping names to modules, the set of names for each dictionary + should be the same, looping over the set of names, for each name take the difference + between the associated modules in each dictionary + + """ + if curr_modules.keys() != prev_modules.keys(): + raise ValueError( + "The keys to the given mappings must have the same set of names of modules" + ) + + summed_norms = torch.tensor(0.0) + if None in prev_modules.values(): + return False + for name in curr_modules: + curr_weight = get_module_weight(curr_modules[name]) + prev_weight = get_module_weight(prev_modules[name]) + + difference = curr_weight.sub(prev_weight) + summed_norms += torch.norm(difference) + return bool(summed_norms < threshold) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_learnable_fake_quantize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_learnable_fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..00b824f8d1ecfe2086576eb3a4c16c4321e9e892 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/_learnable_fake_quantize.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-defs + +import torch +from torch.nn.parameter import Parameter + + +__all__: list[str] = [] + + +class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): + r"""Generalized extension of the FakeQuantize module in fake_quantize.py. + + This is an extension of the FakeQuantize module in fake_quantize.py, which + supports more generalized lower-bit quantization and supports learning of the scale + and zero point parameters through backpropagation. + + In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize + module also includes the following attributes to support quantization parameter learning. + + * :attr:`channel_len` defines the length of the channel when initializing scale and zero point + for the per channel case. + + * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are + normalized by the constant, which is proportional to the square root of the number of + elements in the tensor. The related literature justifying the use of this particular constant + can be found here: https://openreview.net/pdf?id=rkgO66VKDS. + + * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output. + + * :attr:`static_enabled` defines the flag for using observer's static estimation for + scale and zero point. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point. + """ + + def __init__( + self, + observer, + quant_min=0, + quant_max=255, + scale=1.0, + zero_point=0.0, + channel_len=-1, + use_grad_scaling=False, + **observer_kwargs, + ): + super().__init__() + if quant_min >= quant_max: + raise AssertionError("quant_min must be strictly less than quant_max.") + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + self.use_grad_scaling = use_grad_scaling + if channel_len == -1: + self.scale = Parameter(torch.tensor([scale])) + self.zero_point = Parameter(torch.tensor([zero_point])) + else: + if not (isinstance(channel_len, int) and channel_len > 0): + raise AssertionError("Channel size must be a positive integer.") + self.scale = Parameter(torch.tensor([scale] * channel_len)) + self.zero_point = Parameter(torch.tensor([zero_point] * channel_len)) + + self.activation_post_process = observer(**observer_kwargs) + if torch.iinfo(self.activation_post_process.dtype).min > quant_min: + raise AssertionError("quant_min out of bound") + if quant_max > torch.iinfo(self.activation_post_process.dtype).max: + raise AssertionError("quant_max out of bound") + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("static_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8)) + + bitrange = torch.tensor(quant_max - quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps])) + + @torch.jit.export + def enable_param_learning(self): + r"""Enable parameter learning over static observer estimates. + + Enables learning of quantization parameters and + disables static observer estimates. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=True).toggle_fake_quant( + enabled=True + ).toggle_observer_update(enabled=False) + return self + + @torch.jit.export + def enable_static_estimate(self): + """Enable static estimates of quantization parameters. + + Enables static observer estimates and disables learning of + quantization parameters. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False).toggle_fake_quant( + enabled=True + ).toggle_observer_update(enabled=True) + + @torch.jit.export + def enable_static_observation(self): + """Enable accumulation of data without updating quantization parameters. + + Enables static observer accumulating data from input but doesn't + update the quantization parameters. Forward path returns the original X. + """ + self.toggle_qparam_learning(enabled=False).toggle_fake_quant( + enabled=False + ).toggle_observer_update(enabled=True) + + @torch.jit.export + def toggle_observer_update(self, enabled=True): + self.static_enabled[0] = int(enabled) # type: ignore[operator] + return self + + @torch.jit.export + def enable_observer(self, enabled=True): + self.toggle_observer_update(enabled) + + @torch.jit.export + def toggle_qparam_learning(self, enabled=True): + self.learning_enabled[0] = int(enabled) # type: ignore[operator] + self.scale.requires_grad = enabled + self.zero_point.requires_grad = enabled + return self + + @torch.jit.export + def toggle_fake_quant(self, enabled=True): + self.fake_quant_enabled[0] = int(enabled) + return self + + @torch.jit.export + def observe_quant_params(self): + print(f"_LearnableFakeQuantize Scale: {self.scale.detach()}") + print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}") + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + scale = self.scale.detach() + zero_point = ( + self.zero_point.detach() + .round() + .clamp(self.quant_min, self.quant_max) + .long() + ) + return scale, zero_point + + def forward(self, X): + if self.static_enabled[0] == 1: # type: ignore[index] + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point) + else: + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + + if self.fake_quant_enabled[0] == 1: + if self.qscheme in ( + torch.per_channel_symmetric, + torch.per_tensor_symmetric, + ): + self.zero_point.data.zero_() + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 + else: + grad_factor = 1.0 + if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.quant_min, + self.quant_max, + grad_factor, + ) + else: + X = torch._fake_quantize_learnable_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.quant_min, + self.quant_max, + grad_factor, + ) + + return X diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dd69022cbaad3425ea38dbd9b96a8ad7f4c1424 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ffa014be83098a4d011b005b44438282020c620 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb4d209dda739c507e318745817490260bede243 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ca80266c808b0fb3a22f61c244444e9c41f139e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bad44de2b20c90064d9a240d447e806034849e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05e0306dead285227eefe85a9cf87026ea2518b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18401bdb72d8dfc448ca7b01050d5fcdf324c6d2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5248f942f4b93cc845af94cb86cd1c3ba197007b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..909a51dbcf73daaca0137802e8d000ecfd7896e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ede4bf948a3dfcb327c5286877b9193e80b7e489 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c8305119649b9c435733d78762081b2090d7ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78bc46f409be6175d0e7e02f805b4c6b4c0c9e19 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fake_quantize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a380946c8a06dd884680fc52cf1350f49772f8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fake_quantize.py @@ -0,0 +1,663 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +"""Implements modules used to perform fake quantization.""" + +import re +from abc import ABC, abstractmethod +from typing import Any + +import torch +from torch.ao.quantization.observer import ( + _with_args, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + FixedQParamsObserver, + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, +) +from torch.nn import Module + + +__all__ = [ + "FakeQuantizeBase", + "FakeQuantize", + "FixedQParamsFakeQuantize", + "FusedMovingAvgObsFakeQuantize", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "default_fake_quant", + "default_weight_fake_quant", + "default_dynamic_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_symmetric_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_fake_quant", + "default_per_channel_weight_fake_quant", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_histogram_fake_quant", + "default_fused_act_fake_quant", + "default_fused_wt_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "fused_wt_fake_quant_range_neg_127_to_127", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", +] + + +def _is_per_channel(qscheme: "torch.qscheme") -> bool: + return qscheme in [ + torch.per_channel_symmetric, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ] + + +def _is_per_tensor(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + + +def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + + +def _is_float_qparams(qscheme: "torch.qscheme") -> bool: + return qscheme == torch.per_channel_affine_float_qparams + + +class FakeQuantizeBase(ABC, Module): + r"""Base fake quantize module. + + Base fake quantize module + Any fake quantize implementation should derive from this class. + + Concrete fake quantize module should follow the same API. In forward, they will update + the statistics of the observed Tensor and fake quantize the input. They should also provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + """ + + fake_quant_enabled: torch.Tensor + observer_enabled: torch.Tensor + + def __init__(self) -> None: + """Set fake_quant_enabled and observer_enabled.""" + super().__init__() + # fake_quant_enabled and observer_enabled are buffers to support their + # replication in DDP. Data type is uint8 because NCCL does not support + # bool tensors. + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8)) + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + @torch.jit.export + def enable_fake_quant(self, enabled: bool = True) -> None: + self.fake_quant_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_fake_quant(self): + self.enable_fake_quant(False) + + @torch.jit.export + def enable_observer(self, enabled: bool = True) -> None: + self.observer_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_observer(self): + self.enable_observer(False) + + @classmethod + def with_args(cls, **kwargs): + fake_quant_constructor = _with_args(cls, **kwargs) + # need to assign the correct module to fake_quantize + # constructors to satisfy public v private requirements + fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize" + return fake_quant_constructor + + +class FakeQuantize(FakeQuantizeBase): + r"""Simulate the quantize and dequantize operations in training time. + + The output of this module is given by:: + + x_out = ( + clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point + ) * scale + + * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization + operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq) + + * :attr:`scale` defines the scale factor used for quantization. + + * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to + + * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that + statistics can still be updated. + + * :attr:`observer_enabled` controls statistics collection on tensors + + * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization, + allowable values are torch.qint8 and torch.quint8. + + Args: + + observer (module): Module for observing statistics on input tensors and calculating scale + and zero-point. + observer_kwargs (optional): Arguments for the observer module + + Attributes: + activation_post_process (Module): User provided module that collects statistics on the input tensor and + provides a method to calculate scale and zero-point. + + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + observer=MovingAverageMinMaxObserver, + quant_min=None, + quant_max=None, + is_dynamic=False, + **observer_kwargs, + ): + super().__init__() + # Populate quant_min/quant_max to observer_kwargs if valid + if quant_min is not None and quant_max is not None: + if quant_min > quant_max: + raise AssertionError( + "quant_min must be less than or equal to quant_max" + ) + dtype = observer_kwargs.get("dtype", torch.quint8) + if hasattr(observer, "p"): + # In case observer is _PartialWrapper, dtype can be stored in + # observer.p.keywords["dtype"] + dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get( + "dtype", dtype + ) + # pyrefly: ignore [bad-argument-type] + if torch.iinfo(dtype).min > quant_min: + raise AssertionError("quant_min out of bound") + # pyrefly: ignore [bad-argument-type] + if quant_max > torch.iinfo(dtype).max: + raise AssertionError("quant_max out of bound") + observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max}) + observer_kwargs["is_dynamic"] = is_dynamic + self.activation_post_process = observer(**observer_kwargs) + # TODO: keeping self.quant_min/max for BC; remove after a couple releases + # Users should use self.activation_post_process.quant_min + self.quant_min = self.activation_post_process.quant_min + self.quant_max = self.activation_post_process.quant_max + self.is_dynamic = self.activation_post_process.is_dynamic + if _is_float_qparams(self.activation_post_process.qscheme): + zero_point_dtype = torch.float + else: + zero_point_dtype = torch.int + self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype)) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + if not (_is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme)): + raise AssertionError( + "Only per channel and per tensor quantization are supported in fake quantize" + + " got qscheme: " + + str(self.qscheme) + ) + self.is_per_channel = _is_per_channel(self.qscheme) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.activation_post_process.calculate_qparams() + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = ( + _scale.to(self.scale.device), + _zero_point.to(self.zero_point.device), + ) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + else: + X = torch.fake_quantize_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + return X + + @torch.jit.export + def extra_repr(self): + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " + f"scale={self.scale}, zero_point={self.zero_point}" + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + # We cannot currently register scalar values as buffers, so need to manually + # specify serialization here. + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = self.scale + destination[prefix + "zero_point"] = self.zero_point + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # Removing this function throws an error that the size of the loaded tensor does not match the original size + # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass. + local_state = ["scale", "zero_point"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == "scale": + self.scale.resize_(val.shape) + else: + if name != "zero_point": + raise AssertionError( + "Expected 'zero_point' but got different state key" + ) + self.zero_point.resize_(val.shape) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == "scale": + self.scale.copy_(val) + else: + if name != "zero_point": + raise AssertionError( + "Expected 'zero_point' but got different state key" + ) + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class FixedQParamsFakeQuantize(FakeQuantize): + """Simulate quantize and dequantize in training time. + + Simulate quantize and dequantize with fixed quantization + parameters in training time. Only per tensor quantization + is supported. + """ + + # TODO: rename observer to observer_ctr + def __init__(self, observer): + super().__init__(observer=observer) + if type(self.activation_post_process) is not FixedQParamsObserver: + raise AssertionError( + f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + ) + self._observer_ctr = observer + self.scale = self.activation_post_process.scale + self.zero_point = self.activation_post_process.zero_point + if not _is_per_tensor(self.qscheme): + raise AssertionError( + "Only per tensor quantization is supported" + + " FixedQParamsFakeQuantize module, got qscheme:" + + str(self.qscheme) + ) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self): + """Define a string representation of the object's attributes.""" + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, " + f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, " + f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}" + ) + + +class FusedMovingAvgObsFakeQuantize(FakeQuantize): + r"""Define a fused module to observe the tensor. + + Fused module that is used to observe the input tensor (compute min/max), compute + scale/zero_point and fake_quantize the tensor. + This module uses calculation similar MovingAverageMinMaxObserver for the inputs, + to compute the min/max values in order to compute the scale/zero_point. + The qscheme input in the observer is used to differentiate between symmetric/affine + quantization scheme. + + The output of this module is given by + x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale + + Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the + base class. + + """ + + def __init__( + self, + observer: Any = MovingAverageMinMaxObserver, + quant_min: int = 0, + quant_max: int = 255, + **observer_kwargs: Any, + ) -> None: + super().__init__(observer, quant_min, quant_max, **observer_kwargs) + if not isinstance( + self.activation_post_process, + (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver), + ): + raise AssertionError( + "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + ) + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) + self.is_symmetric_quant = _is_symmetric_quant( + self.activation_post_process.qscheme + ) + + @torch.jit.export + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override] + return self.activation_post_process.calculate_qparams() + + @torch.jit.export + def extra_repr(self) -> str: + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}" + ) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + return torch.fused_moving_avg_obs_fake_quant( + X, + self.observer_enabled, + self.fake_quant_enabled, + self.activation_post_process.min_val, + self.activation_post_process.max_val, + self.scale, + self.zero_point, + self.activation_post_process.averaging_constant, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + self.ch_axis, + self.is_per_channel, + self.is_symmetric_quant, + ) + + +default_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Default fake_quant for activations. +""" + +default_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=False, +) +""" +Default fake_quant for weights. +Observer is memoryless since averaging_constant is 1. +""" + +default_dynamic_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + is_dynamic=True, + dtype=torch.quint8, + averaging_constant=1, +) +""" +Default dynamic fake_quant for activations. +""" + +default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_neg1to1_observer +) +default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_0to1_observer +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_fake_quant = ( + default_fixed_qparams_range_neg1to1_fake_quant +) +default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant + +default_per_channel_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + ch_axis=0, +) +""" +Default fake_quant for per-channel weights. +Observer is memoryless since averaging_constant is 1. +""" +default_embedding_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + dtype=torch.quint8, + quant_min=0, + quant_max=255, + ch_axis=0, + averaging_constant=1, +) +""" +Default fake_quant for embeddings. +Observer is memoryless since averaging_constant is 1. +""" + +default_embedding_fake_quant_4bit = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + dtype=torch.quint4x2, + averaging_constant=1, +) + +default_histogram_fake_quant = FakeQuantize.with_args( + observer=HistogramObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Fake_quant for activations using a histogram.. +""" + + +default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, +) + +""" +Fused version of `default_fake_quant`, with improved performance. +""" + + +default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, +) +""" +Fused version of `default_weight_fake_quant`, with improved performance. +""" + +default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, +) +""" +Fused version of `default_per_channel_weight_fake_quant`, with improved performance. +""" + +fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + eps=2**-12, +) +""" +Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +fused_per_channel_wt_fake_quant_range_neg_127_to_127 = ( + FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + eps=2**-12, + ) +) + +""" +Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. +""" + + +def _is_fake_quant_script_module(mod): + """Return true if given mod is an instance of FakeQuantize script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return ( + name == "torch.ao.quantization.fake_quantize.FakeQuantize" + or name + == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize" + ) + return False + + +def disable_fake_quant(mod): + """Disable fake quantization for the module. + + Disable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_fake_quant() + + +def enable_fake_quant(mod): + """Enable fake quantization for the module. + + Enable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_fake_quant() + + +def disable_observer(mod): + """Disable observation for this module. + + Disable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_observer() + + +def enable_observer(mod): + """Enable observation for this module. + + Enable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_observer() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fuse_modules.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fuse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4f664c699144917d3314eee7bdf5dd92f9697108 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fuse_modules.py @@ -0,0 +1,215 @@ +# mypy: allow-untyped-defs +import copy + +import torch.nn as nn + +# for backward compatibility +from torch.ao.quantization.fuser_method_mappings import ( # noqa: F401 # noqa: F401 + fuse_conv_bn, + fuse_conv_bn_relu, + get_fuser_method, +) +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "fuse_known_modules", + "fuse_modules", + "fuse_modules_qat", +] + + +# Generalization of getattr +def _get_module(model, submodule_key): + tokens = submodule_key.split(".") + cur_mod = model + for s in tokens: + cur_mod = getattr(cur_mod, s) + return cur_mod + + +# Generalization of setattr +def _set_module(model, submodule_key, module): + tokens = submodule_key.split(".") + sub_tokens = tokens[:-1] + cur_mod = model + for s in sub_tokens: + cur_mod = getattr(cur_mod, s) + + setattr(cur_mod, tokens[-1], module) + + +def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None): + r"""Return a list of known fuse modules. + + Returns a list of modules that fuses the operations specified + in the input module list. + + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, bn + linear, relu + For these sequences, the first element in the output module list performs + the fused operation. The rest of the elements are set to nn.Identity() + """ + types = tuple(type_before_parametrizations(m) for m in mod_list) + fuser_method = get_fuser_method(types, additional_fuser_method_mapping) + if fuser_method is None: + raise NotImplementedError(f"Cannot fuse modules: {types}") + new_mod: list[nn.Module | None] = [None] * len(mod_list) + fused = fuser_method(is_qat, *mod_list) + # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion + # Move pre forward hooks of the base module to resulting fused module + for pre_hook_fn in mod_list[0]._forward_pre_hooks.values(): + fused.register_forward_pre_hook(pre_hook_fn) + mod_list[0]._forward_pre_hooks.clear() + # Move post forward hooks of the last module to resulting fused module + for hook_fn in mod_list[-1]._forward_hooks.values(): + fused.register_forward_hook(hook_fn) + mod_list[-1]._forward_hooks.clear() + new_mod[0] = fused + + for i in range(1, len(mod_list)): + identity = nn.Identity() + identity.training = mod_list[0].training + new_mod[i] = identity + + return new_mod + + +def _fuse_modules_helper( + model, + modules_to_fuse, + is_qat, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get( + "additional_fuser_method_mapping", {} + ) + mod_list = [_get_module(model, item) for item in modules_to_fuse] + + # Fuse list of modules + new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping) + + # Replace original module list with fused module list + for i, item in enumerate(modules_to_fuse): + _set_module(model, item, new_mod_list[i]) + + +def _fuse_modules( + model, + modules_to_fuse, + is_qat, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + if not inplace: + model = copy.deepcopy(model) + + if all(isinstance(module_element, str) for module_element in modules_to_fuse): + # Handle case of modules_to_fuse being a list + _fuse_modules_helper( + model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict + ) + else: + # Handle case of modules_to_fuse being a list of lists + for module_list in modules_to_fuse: + _fuse_modules_helper( + model, module_list, is_qat, fuser_func, fuse_custom_config_dict + ) + return model + + +def fuse_modules( + model, + modules_to_fuse, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + r"""Fuse a list of modules into a single module. + + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, relu + bn, relu + All other sequences are left unchanged. + For these sequences, replaces the first item in the list + with the fused module, replacing the rest of the modules + with identity. + + Args: + model: Model containing the modules to be fused + modules_to_fuse: list of list of module names to fuse. Can also be a list + of strings if there is only a single list of modules to fuse. + inplace: bool specifying if fusion happens in place on the model, by default + a new model is returned + fuser_func: Function that takes in a list of modules and outputs a list of fused modules + of the same length. For example, + fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()] + Defaults to torch.ao.quantization.fuse_known_modules + `fuse_custom_config_dict`: custom configuration for fusion + + .. code-block:: python + + # Example of fuse_custom_config_dict + fuse_custom_config_dict = { + # Additional fuser_method mapping + "additional_fuser_method_mapping": { + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn + }, + } + + Returns: + model with fused modules. A new copy is created if inplace=True. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = M().eval() + >>> # m is a module containing the sub-modules below + >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + >>> m = M().eval() + >>> # Alternately provide a single list of modules to fuse + >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + """ + return _fuse_modules( + model, + modules_to_fuse, + is_qat=False, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict, + ) + + +def fuse_modules_qat( + model, + modules_to_fuse, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + """QAT version for `fuse_modules`.""" + return _fuse_modules( + model, + modules_to_fuse, + is_qat=True, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fuser_method_mappings.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fuser_method_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..d72a3579438bc3e5e2687982ab4b550c680d2110 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fuser_method_mappings.py @@ -0,0 +1,314 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Callable +from typing import Any + +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +from torch.ao.quantization.utils import get_combined_dict, MatchAllNode, Pattern + + +__all__ = [ + "fuse_conv_bn", + "fuse_conv_bn_relu", + "fuse_linear_bn", + "fuse_convtranspose_bn", + "get_fuser_method", + "get_fuser_method_new", +] + + +def fuse_conv_bn(is_qat, conv, bn): + r"""Return the fused the conv and bn modules. + Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn(m1, b1) + """ + if conv.training != bn.training: + raise AssertionError( + "Conv and BN both must be in the same mode (train or eval)." + ) + + fused_module_class_map = { + nn.Conv1d: nni.ConvBn1d, + nn.Conv2d: nni.ConvBn2d, + nn.Conv3d: nni.ConvBn3d, + } + + if is_qat: + if bn.num_features != conv.out_channels: + raise AssertionError( + "Output channel of Conv2d must match num_features of BatchNorm2d." + ) + if not bn.affine: + raise AssertionError( + "Only support fusing BatchNorm2d with affine set to True" + ) + if not bn.track_running_stats: + raise AssertionError( + "Only support fusing BatchNorm2d with tracking_running_stats set to True" + ) + fused_module_class = fused_module_class_map.get((type(conv)), None) + if fused_module_class is not None: + return fused_module_class(conv, bn) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}") + else: + return nn.utils.fuse_conv_bn_eval(conv, bn) + + +def fuse_conv_bn_relu(is_qat, conv, bn, relu): + r"""Return the fused conv and bv modules. + + Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> r1 = nn.ReLU(inplace=False) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn_relu(m1, b1, r1) + """ + if not (conv.training == bn.training == relu.training): + raise AssertionError( + "Conv and BN both must be in the same mode (train or eval)." + ) + fused_module: type[nn.Sequential] | None = None + if is_qat: + map_to_fused_module_train = { + nn.Conv1d: nni.ConvBnReLU1d, + nn.Conv2d: nni.ConvBnReLU2d, + nn.Conv3d: nni.ConvBnReLU3d, + } + if bn.num_features != conv.out_channels: + raise AssertionError( + "Output channel of Conv2d must match num_features of BatchNorm2d" + ) + if not bn.affine: + raise AssertionError( + "Only support fusing BatchNorm2d with affine set to True" + ) + if not bn.track_running_stats: + raise AssertionError( + "Only support fusing BatchNorm2d with tracking_running_stats set to True" + ) + fused_module = map_to_fused_module_train.get(type(conv), None) + if fused_module is not None: + return fused_module(conv, bn, relu) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}") + else: + map_to_fused_module_eval = { + nn.Conv1d: nni.ConvReLU1d, + nn.Conv2d: nni.ConvReLU2d, + nn.Conv3d: nni.ConvReLU3d, + } + fused_module = map_to_fused_module_eval.get(type(conv), None) + if fused_module is not None: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return fused_module(fused_conv, relu) + else: + raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}") + + +def fuse_linear_bn(is_qat, linear, bn): + r"""Return the fused linear and bn modules. + Given the linear and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + + Examples:: + + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> # xdoctest: +SKIP + >>> m2 = fuse_linear_bn(m1, b1) + """ + if linear.training != bn.training: + raise AssertionError( + "Linear and BN both must be in the same mode (train or eval)." + ) + + if is_qat: + if bn.num_features != linear.out_features: + raise AssertionError( + "Output features of Linear must match num_features of BatchNorm1d" + ) + if not bn.affine: + raise AssertionError( + "Only support fusing BatchNorm1d with affine set to True" + ) + if not bn.track_running_stats: + raise AssertionError( + "Only support fusing BatchNorm1d with tracking_running_stats set to True" + ) + return nni.LinearBn1d(linear, bn) + else: + return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + + +def fuse_convtranspose_bn(is_qat, convt, bn): + r"""Return the fused ConvTranspose and bn modules. + Given ConvTranspose and bn modules, fuses them and returns the fused module + + Args: + convt: Module instance of type ConvTransposeNd + bn: BatchNormNd instance that needs to be fused with the linear layer. + batch norm N should match the ConvTranspose N + + Examples:: + + >>> m1 = nn.ConvTranspose2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_convtranspose_bn(m1, b1) + """ + if convt.training != bn.training: + raise AssertionError( + "ConvTranspose and BN both must be in the same mode (train or eval)." + ) + + if is_qat: + raise Exception( # noqa: TRY002 + "Fusing ConvTranspose+BatchNorm not yet supported in QAT." + ) + else: + return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True) + + +def _sequential_wrapper2(sequential): + """Return a sequential wrapped that for is_qat and two modules. + Given a sequential class for two modules, return a function that takes + is_qat, and then two modules as argument, that ignores the is_qat flag + and always returns the sequential that combines the two input modules + """ + + def fuser_method(is_qat, m1, m2): + return sequential(m1, m2) + + return fuser_method + + +_DEFAULT_OP_LIST_TO_FUSER_METHOD: dict[tuple, nn.Sequential | Callable] = { + (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, + (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, + (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn, + (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d), + (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d), + (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d), + (nn.Linear, nn.BatchNorm1d): fuse_linear_bn, + (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU), + (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d), + (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d), + (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn, + (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn, + (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn, +} + + +def get_fuser_method(op_list, additional_fuser_method_mapping=None): + """Get fuser method for the given list of module types. + + Get fuser method for the given list of module types, + return None if fuser method does not exist + """ + if additional_fuser_method_mapping is None: + additional_fuser_method_mapping = {} + all_mappings = get_combined_dict( + _DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping + ) + fuser_method = all_mappings.get(op_list, None) + if fuser_method is None: + raise AssertionError(f"did not find fuser method for: {op_list} ") + return fuser_method + + +def _reverse2(f): + def reversed(is_qat, x, y): + return f(is_qat, y, x) + + return reversed + + +def _reverse3(f): + def reversed(is_qat, x, w): + y, z = w + return f(is_qat, z, y, x) + + return reversed + + +def _get_valid_patterns(op_pattern): + """Return a list of valid patterns generated from the op_pattern. + + Returns a list of valid patterns generated from the op_pattern, + since MatchAllNode can match all types of nodes, + e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like + (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode) + + Example Input: + (torch.add, (torch.nn.ReLU, torch.nn.Conv2d)) + + Example Output: + [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)), + (torch.add, (torch.nn.ReLU, MatchAllNode)), + (torch.add, (MatchAllNode, torch.nn.Conv2d)), + (torch.add, (MatchAllNode, MatchAllNode)), + (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)), + (MatchAllNode, (torch.nn.ReLU, MatchAllNode)), + (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)), + (MatchAllNode, (MatchAllNode, MatchAllNode)), + ] + """ + result: list[Any] + if isinstance(op_pattern, (tuple, list)): + sub_combs = [_get_valid_patterns(sub_pattern) for sub_pattern in op_pattern] + result = list(itertools.product(*sub_combs)) + else: + result = [op_pattern, MatchAllNode] + return result + + +def get_fuser_method_new( + op_pattern: Pattern, + fuser_method_mapping: dict[Pattern, nn.Sequential | Callable], +): + """Get fuser method. + + This will be made default after we deprecate the get_fuser_method + Would like to implement this first and have a separate PR for deprecation + """ + op_patterns = _get_valid_patterns(op_pattern) + fuser_method = None + for op_pattern in op_patterns: + fuser_method = fuser_method_mapping.get(op_pattern) + if fuser_method is not None: + break + if fuser_method is None: + raise AssertionError(f"did not find fuser method for: {op_pattern} ") + return fuser_method diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..402ba073147395b54b6b4d887b74f7242fb4651a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74da54c4faf5524ebe3f15a9a66236c8ce53d4b2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db2947158975d43fa057a57707f4bf59f1b813d3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95a36595e4274f0b0e401ca99d91369d683e07a6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4399cce813e4591d4a4ff9b5c0bccd498a1c2862 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3a4408a7f31dabd7c6353bbab25bb6bb33ed76d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50a4db3e47374fdfaa0f9c9ff8356e914bf8ae5b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dbede6c86eae4b439e22bfb2e7a0404d09d62c1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..736ed6bd531e3c97f944a8137e0c0439ae23ad4c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..639f18b2bfbcacdb2cdd33a0fba9120c305677ba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..859dfd567710176b88bc4bd30bc2f9e4e3703d66 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..035c2a734ab3dd6aecfed2c341b7fb046aa7b16f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53aa7f9c752399be8715e7a448ee2145db8db876 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b5cd6527a54ea91d9d53e4c254b9e83c89b76aa Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b66272334ad4c3edb52cfdacb8ae070c4915033c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..942230b6c0f2ac4b69c5fd087e3c4ee0805fd1e8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78c858b3f72ec163187a6af82a13195bf0b403a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..249405e5e4d24ba2ba665311579608cd2196192f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f42c5a1419bf918f19a801cef0a115ec86a4185 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db4f67d532eea702296f2b7064c434662d42017 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c11f6ddd89dd531e1193e4e85eb2ebc68429497c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42a30a7fb8269099e4411c823aec3c07f1a30421 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91021f4f76b7004cccc25a080f3bfb771c41bf3c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4c40727681de4df48f7b97baebe5266f16925a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0a48bbbaaee901871d41396e0583642c4d486dce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py @@ -0,0 +1,1743 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.nn.qat as nnqat +import torch.nn as nn +from torch.ao.quantization.fake_quantize import FakeQuantize +from torch.ao.quantization.fx._equalize import ( + default_equalization_qconfig, + EqualizationQConfig, +) +from torch.ao.quantization.fx._model_report.model_report_observer import ( + ModelReportObserver, +) +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ( + _is_activation_post_process, + default_dynamic_quant_observer, + default_observer, + default_per_channel_weight_observer, + default_weight_observer, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( + _assert_valid_qconfig, + default_qconfig, + QConfig, +) + + +# Names for observer insert keys +DETECTOR_TARGET_NODE_KEY = "target_node" +DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert" +DETECTOR_IS_POST_OBS_KEY = "is_post_observer" +DETECTOR_OBS_ARGS_KEY = "observer_args" + + +# Mapping related code +class DetectorQConfigInfo: + r""" + This class contains the QConfig information for a single module. + The list of variables / values this contains can grow depending on the + extensibility of the qconfig mapping feature set but this currently includes: + - if activation observer is dynamic + - if weight observer is per channel + + + Args: + module_fqn (str): The fully qualified name (fqn) of the module that this + information contains info relevant to qconfig for + """ + + def __init__(self, module_fqn: str): + super().__init__() + self.module_fqn = module_fqn + + # populate this section with all the variables we might find important + # change from none if your detector is actually using this + self.is_activation_dynamic = False + self.is_weight_per_channel = False + + # equalization related options + self.is_equalization_recommended = False + + def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig: + r""" + Args: + module (torch.nn.Module) The module we are generating + the qconfig for + + Returns the generated quantization QConfig according to what a valid configuration is + """ + # Apply suggestions to new qconfig + module_qconfig = default_qconfig + + # keep track of dynamic and per_channel recommendations + recommendations_list = [] + # append as if a list of combinations + recommendations_list.append( + (self.is_activation_dynamic, self.is_weight_per_channel) + ) + recommendations_list.append( + (self.is_activation_dynamic, False) + ) # only trying dynamic rec + recommendations_list.append( + (False, self.is_weight_per_channel) + ) # only trying dynamic + + # now we try each of the combinations + for rec in recommendations_list: + # rec[0] -> dynamic recommended + # rec[1] -> per channel recommended + activation = default_dynamic_quant_observer if rec[0] else default_observer + weight = ( + default_per_channel_weight_observer + if rec[1] + else default_weight_observer + ) + test_config = QConfig(activation, weight) + try: + _assert_valid_qconfig(test_config, module) + module_qconfig = test_config + break + except AssertionError: + # if not a valid configuration, we move on to the next one in priority + continue + + # return the QConfig chosen + return module_qconfig + + def generate_equalization_qconfig(self) -> EqualizationQConfig: + r""" + This returns the equalization configuration for a module. + + For now, it just returns the default, but as more equalization options become + possible, this method can get more fleshed out with more nuanced granularity. + + + Returns the generated equalization QConfig according to what a valid configuration is + """ + # in this case, we just return default equalization config + # we know this is valid because only valid modules would even + # have this option + return default_equalization_qconfig + + +# Adding base class for detectors +class DetectorBase(ABC): + r"""Base Detector Module + Any detector class should derive from this class. + + Concrete detectors should follow the same general API, which includes: + - A method to calculate and return observer insertion points + - Should return both the fqns and the Observer class to insert + - A method to return a report based on the detector + - Should return a str-based report and dict info in Tuple[str,Dict] format + """ + + def __init__(self) -> None: + super().__init__() + self.detector_config_info = None + + @abstractmethod + def determine_observer_insert_points(self, model) -> dict: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict. + This dict maps string keys to detector specific information + """ + + @abstractmethod + def get_detector_name(self) -> str: + r"""Returns the name of the current detector""" + + @abstractmethod + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + + def _get_targeting_node( + self, prepared_fx_model: GraphModule, target_fqn: str + ) -> torch.fx.node.Node: + r""" + Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn. + + If it's not found, it means it is most likely inside a fused layer + We just go one layer up in terms of the fqn we are searching for until we find parent node + If we get to empty string, then we know that it doesn't exist + + The reason for the recursion is that if the model that we are looking for got fused, + we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module, + which would have fqn as x.linear so they will not match. + To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear, + or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module + even in cases with fusion + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + target_fqn (str): The fqn of the layer we are trying to target + + Returns the node object we are trying to add observers around + """ + for node in prepared_fx_model.graph.nodes: + # if the node's target is our target, return it + if node.target == target_fqn: + return node + + # getting here means node not found + # if no "." we are already at base and failed + parent_fqn_sep_index = target_fqn.rfind(".") + if parent_fqn_sep_index == -1: + raise ValueError("passed in target_fqn not found in graph's targets.") + else: + # recursively call it with parent fqn + return self._get_targeting_node( + prepared_fx_model, target_fqn[:parent_fqn_sep_index] + ) + + @abstractmethod + def generate_detector_report(self, model) -> tuple[str, dict[str, Any]]: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Tuple of two elements: + Str: string report of the suggested improvements + Dict: contains useful data collected by the observer pertinent to this report + """ + + +class PerChannelDetector(DetectorBase): + r"""This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization. + Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. + + per_channel quantization can lead to major benefits in the form of accuracy. + Therefore, if the backend used by the user supports it, it is recommended to use + + Args: + backend (str, optional): the backend the user wishes to use in production + Default value is current torch.backends.quantized.engine + """ + + # Keys for return dictionary + BACKEND_KEY = "backend" + PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported" + PER_CHAN_USED_KEY = "per_channel_quantization_used" + + # Default map for representing supported per channel quantization modules for different backends + DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: dict[str, set[Any]] = { + "fbgemm": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "qnnpack": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "onednn": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "x86": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + } + + def __init__(self, backend: str = torch.backends.quantized.engine): + super().__init__() + + # store the backend information + self.backend_chosen = backend + self.supported_modules = set() + if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: + self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[ + self.backend_chosen + ] + else: + raise ValueError( + f"Not configured to work with {self.backend_chosen}. Try a different default backend" + ) + + def get_detector_name(self) -> str: + r"""returns the string name of this detector""" + return "per_channel_detector" + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in per_channel_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + per_chan_supported: bool = per_channel_info[module_fqn][ + self.PER_CHAN_SUPPORTED_KEY + ] + detector_qconfig_info.is_weight_per_channel = per_chan_supported + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def determine_observer_insert_points(self, model: nn.Module) -> dict: + r""" + There is no observers inserted for the PerChannelDetector. + + Returns an empty dictionary since no observers are added or needed + """ + return {} + + def _detect_per_channel_helper(self, model: nn.Module): + r""" + determines if per_channel quantization is supported in modules and submodules. + + Returns a dictionary in the higher level _detect_per_channel function. + Each entry maps the fully-qualified-name to information on whether per_channel quantization. + + Args: + model: The current module that is being checked to see if it is per_channel quantizable + + Returns dictionary mapping fqns to if per_channel quantization is possible + """ + # create dict we will return + per_channel_info: dict = {} + + # get the fully qualified name and check if in list of modules to include and list of modules to ignore + for fqn, module in model.named_modules(): + is_in_include_list = any( + isinstance(module, x) for x in self.supported_modules + ) + + # check if the module per_channel is supported + # based on backend + per_channel_supported = False + + if is_in_include_list: + per_channel_supported = True + + # assert statement for MyPy + q_config_file = module.qconfig + if not isinstance(q_config_file, QConfig): + raise AssertionError("module.qconfig must be a QConfig") + + # this object should either be fake quant or observer + q_or_s_obj = module.qconfig.weight.p.func() + if not isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)): + raise AssertionError( + "module.qconfig.weight must be a FakeQuantize or ObserverBase" + ) + + per_channel_used = False # will be true if found in qconfig + + if hasattr( + q_or_s_obj, "ch_axis" + ): # then we know that per_channel quantization used + # all fake quants have channel axis so need to check is_per_channel + if isinstance(q_or_s_obj, FakeQuantize): + if ( + hasattr(q_or_s_obj, "is_per_channel") + and q_or_s_obj.is_per_channel + ): + per_channel_used = True + elif isinstance(q_or_s_obj, ObserverBase): + # should be an observer otherwise + per_channel_used = True + else: + raise ValueError("Should be either observer or fake quant") + + per_channel_info[fqn] = { + self.PER_CHAN_SUPPORTED_KEY: per_channel_supported, + self.PER_CHAN_USED_KEY: per_channel_used, + self.BACKEND_KEY: self.backend_chosen, + } + + return per_channel_info + + def generate_detector_report(self, model: nn.Module) -> tuple[str, dict[str, Any]]: + r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization. + Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. + + Looks at q_config format and backend to determine if per_channel can be utilized. + Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support + + Args: + model: The prepared and calibrated model we want to check if using per_channel + + Returns a tuple with two elements: + String report of potential actions to improve model (if per_channel quantization is available in backend) + Dictionary mapping per_channel quantizable elements to: + whether per_channel quantization is supported by the backend + if it is being utilized in the current model + """ + + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # String to let the user know of further optimizations + further_optims_str = ( + f"Further Optimizations for backend {self.backend_chosen}: \n" + ) + + optimizations_possible = False + for fqn in per_channel_info: + fqn_dict = per_channel_info[fqn] + if ( + fqn_dict[self.PER_CHAN_SUPPORTED_KEY] + and not fqn_dict[self.PER_CHAN_USED_KEY] + ): + optimizations_possible = True + further_optims_str += ( + f"Module {fqn} can be configured to use per_channel quantization.\n" + ) + + if optimizations_possible: + further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer." + else: + further_optims_str += "No further per_channel optimizations possible." + + # return the string and the dictionary form of same information + return (further_optims_str, per_channel_info) + + +class DynamicStaticDetector(DetectorBase): + r""" + Determines whether dynamic or static quantization is more appropriate for a given module. + + Takes advantage of the ModelReportObserver that records range information. + Stationary distribution of data are strictly above tolerance level for the comparison statistic: + + S = average_batch_activation_range/epoch_activation_range + + Nonstationary distributions are below or at the tolerance level for this metric. + + If the distribution of data right after the module is non-stationary, recommend dynamic quantization + Otherwise recommend static quantization + + Args: + tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5 + """ + + # names for the pre and post observers that are inserted + DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer" + DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer" + + # naming conventions for stationary vs non-stationary data + STATIONARY_STR = "stationary" + NON_STATIONARY_STR = "non-stationary" + + # naming for activation + INPUT_ACTIVATION_PREFIX = "input_activation_" + OUTPUT_ACTIVATION_PREFIX = "output_activation_" + + # naming conventions for the keys of the return module info + TOLERANCE_KEY = "dynamic_static_tolerance" + DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended" + PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + PRE_OBS_DATA_DIST_KEY = ( + INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + ) + POST_OBS_DATA_DIST_KEY = ( + OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + ) + IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported" + + # modules that are supported both dynamic and static for this report function + DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear} + + # modules that will be supported soon for both + DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d} + + def __init__(self, tolerance=0.5): + super().__init__() + + # set tolerance level and initialize a set to keep track of useful fqn locations + self.tolerance = tolerance + self.useful_observer_fqns: set[str] = set() + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r""" + Determines where observers need to be inserted for the Dynamic vs Static detector. + For this detector, we want to place observers on either side of linear layers in the model. + + Currently inserts observers for: + linear layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # make sure module is supported + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + # add entry for post-observer + post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME + + obs_fqn_to_info[post_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: True, + DETECTOR_OBS_ARGS_KEY: (targeted_node,), + } + + return obs_fqn_to_info + + def get_detector_name(self) -> str: + r"""returns the string name of this detector""" + return "dynamic_vs_static_detector" + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + dynamic_static_info = self._generate_dict_info(model) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in dynamic_static_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + dynamic_static_recommended: bool = dynamic_static_info[module_fqn][ + self.DEFAULT_DYNAMIC_REC_KEY + ] + detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = any( + isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED + ) + + # check if it will be supported + future_supported_type = any( + isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED + ) + + # supported + supported = is_supported_type or future_supported_type + + # this is check for observer insertion + if insert: + return supported + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr( + module, self.DEFAULT_POST_OBSERVER_NAME + ) + return supported and has_obs + + def _generate_dict_info(self, model: GraphModule) -> dict[str, Any]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a Dictionary mapping modules with ModelReportObservers around them to: + whether dynamic quantization is recommended + their S metric of input to module + whether input to module is stationary or non-stationary + their S metric of output of module + whether output of module is stationary or non-stationary + the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future + """ + # store modules dynamic vs static information + module_dynamic_static_info = {} + + # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info + # This information primary includes whether the data distributions around a supported module is stationary or not + # Based on this, it is recorded whether dynamic or static quantization is recommended + + # loop through all submodules included nested ones + for fqn, module in model.named_modules(): + # if module is Linear has the ModelReportObserver attached to it + if self._is_supported(module): + # get pre and post observers for the module + pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME) + + # get the statistics for each module + pre_stat = pre_obs.get_batch_to_epoch_ratio() + post_stat = post_obs.get_batch_to_epoch_ratio() + + # record module, pre and post stat, and whether to do dynamic or static based off it + # true if post observer data distribution is non-stationary, false if it's stationary + dynamic_recommended = post_stat <= self.tolerance + + # specify the classifications for whether data distributions considered stationary or non-stationary + pre_obs_dist_classif = ( + self.STATIONARY_STR + if pre_stat > self.tolerance + else self.NON_STATIONARY_STR + ) + post_obs_dist_classif = ( + self.STATIONARY_STR + if post_stat > self.tolerance + else self.NON_STATIONARY_STR + ) + + # check if current support or future support + is_supported_type = any( + isinstance(module, x) + for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED + ) + + # store the set of important information for this module + module_info = { + self.TOLERANCE_KEY: self.tolerance, + self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended, + self.PRE_OBS_COMP_STAT_KEY: pre_stat, + self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif, + self.POST_OBS_COMP_STAT_KEY: post_stat, + self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif, + self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type, + } + + module_dynamic_static_info[fqn] = module_info + + return module_dynamic_static_info + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether dynamic or static quantization is more appropriate for a given module. + + Takes advantage of the ModelReportObserver that records range information. + Stationary distribution of data are strictly above tolerance level for the comparison statistic: + + S = average_batch_activation_range/epoch_activation_range + + Nonstationary distributions are below or at the tolerance level for this metric. + + If the distribution of data right after the module is non-stationary, recommend dynamic quantization + Otherwise recommend static quantization + + This will then generate suggestions for dynamic vs static quantization focused around Linear. + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether dynamic or static quantization is recommended for certain modules + Dictionary mapping modules with ModelReportObservers around them to: + whether dynamic quantization is recommended + their S metric of input to module + whether input to module is stationary or non-stationary + their S metric of output of module + whether output of module is stationary or non-stationary + the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future + """ + + # get the dictionary of the information to format the string report + module_dynamic_static_info = self._generate_dict_info(model) + + dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n" + + modules_added: bool = False # check to make sure at least 1 module added. + + dynamic_benefit = ( + " You will get more accurate results if you use dynamic quantization" + ) + static_benefit = ( + " You can increase model efficiency if you use static quantization" + ) + future_support_str = ( + ". This layer is not yet supported for dynamic quantization" + ) + # This for loop goes through the information collected in module_dynamic_static_info and: + # Populates the string based report with the information from module_dynamic_static_info + # Compiles the complete report by appending relevant formatted strings + + for module_fqn in module_dynamic_static_info: + # there is at least 1 module for suggestion + modules_added = True + module_info = module_dynamic_static_info[module_fqn] + suggestion_string_template = ( + "For module {} it is suggested to use {} quantization because {}.\n" + ) + + # decide what string formatting values will be + quantization_type = "" + quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}." + + benefit_str = "" + + # strings for if dynamic quantized per tensor is needed + recommend_per_tensor = ( + ". We recommend to add a {} before this module if it is static." + ) + rec_lay_to_add = "dynamic quantize per tensor layer" + dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add) + dynamic_per_tensor_reasoning_string = " This is because the input to this module has a non-stationary distribution" + + # start composing explanation + if module_info[self.DEFAULT_DYNAMIC_REC_KEY]: + quantization_type = "dynamic" + # check if currently supported or future supported + benefit_str = dynamic_benefit + if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]: + benefit_str += future_support_str + else: + quantization_type = "static" + benefit_str = static_benefit + + # now set the quantization explanation string + quantization_reasoning = ( + quantization_reasoning.format( + module_fqn, + module_info[self.PRE_OBS_DATA_DIST_KEY], + module_info[self.POST_OBS_DATA_DIST_KEY], + ) + + benefit_str + ) + + # if we have a non-stationary input -> linear -> stationary we suggested static + # however, we want to also recommend they add a dynamic quantize per tensor right if this change is made + if ( + module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR + and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR + ): + quantization_reasoning = ( + quantization_reasoning + + dynamic_per_tensor_string + + dynamic_per_tensor_reasoning_string + ) + + # format the overall suggestion string with the specific inputs + module_suggestion_string = suggestion_string_template.format( + module_fqn, quantization_type, quantization_reasoning + ) + + # append to overall suggestion + dynamic_vs_static_string += module_suggestion_string + + if not modules_added: + dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n" + + # return the string as well as the dictionary of information + return (dynamic_vs_static_string, module_dynamic_static_info) + + +class InputWeightEqualizationDetector(DetectorBase): + r""" + Determines whether input-weight equalization can help improve quantization for certain modules. + + Specifically, this list of modules includes: + linear + conv + + Determines whether input-weight equalization is recommended based on the comp stat: + s_c = sqrt(w_c/W)/sqrt(i_c/I) + where: + w_c is range of weight for channel c, W is range of weight over all channels + i_c is range of input for channel c, I is range of input over all channels + + if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization + + Args: + ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested + Should be between 0 and 1 (both non-inclusive) + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization + + * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + SUPPORTED_MODULES: set[Callable] = { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + } + + # names for the pre and post observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # weight / activation prefix for each of the below info + WEIGHT_PREFIX = "weight_" + ACTIVATION_PREFIX = "input_activation_" + + # string names for keys of info dictionaries + PER_CHANNEL_MAX_KEY = "per_channel_max" + PER_CHANNEL_MIN_KEY = "per_channel_min" + GLOBAL_MAX_KEY = "global_max" + GLOBAL_MIN_KEY = "global_min" + + # keys for return dict of recommendations + RECOMMENDED_KEY = "input_weight_equalization_recommended" + COMP_METRIC_KEY = "input_weight_channel_comparison_metrics" + THRESHOLD_KEY = "input_weight_threshold" + CHANNEL_KEY = "input_weight_channel_axis" + + # default weight and info strings + WEIGHT_STR = "weight" + INPUT_STR = "input" + + # default for what ratio we recommend input weight + DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4 + + def __init__(self, ratio_threshold: float, ch_axis: int = 1): + # ensure passed in inputs are valid + if ratio_threshold <= 0 or ratio_threshold >= 1: + raise ValueError("Make sure threshold is > 0 and < 1") + + # initialize attributes based on args + self.ratio_threshold: float = ratio_threshold + self.ch_axis: int = ch_axis + + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES) + + # this is check for observer insertion + if insert: + return is_supported_type + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + return is_supported_type and has_obs + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + # find the range of inputs + input_values: dict[str, dict] = self._extract_input_info(model) + + # find the range of weights + weight_values: dict[str, dict] = self._extract_weight_info(model) + + # calculate per_channel comparison statistic s_c + comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values( + input_values, weight_values + ) + + # generate the return dictionary + input_weight_equalization_info: dict[str, dict] = self._generate_dict_info( + input_values, weight_values, comp_stats + ) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in input_weight_equalization_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + input_weight_recommended: bool = input_weight_equalization_info[module_fqn][ + self.RECOMMENDED_KEY + ] + detector_qconfig_info.is_equalization_recommended = input_weight_recommended + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r"""Determines where observers need to be inserted for the Input Weight Equalization Detector. + For this detector, we want to place observers in front of supported layers. + + Currently inserts observers for: + linear layers + conv layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "input_weight_equalization_detector" + + def _extract_input_info(self, model: GraphModule) -> dict[str, dict]: + r""" + Takes in a calibrated GraphModule and then finds the relevant observers. + It then extracts the input information for each observer returns it + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relevant module fqns (str) to a dict with keys: + "input_activation_per_channel_max" : maps to the per_channel max values + "input_activation_per_channel_min" : maps to the per_channel min values + "input_activation_global_max" : maps to the global max recorded + "input_activation_global_min" : maps to the global min recorded + """ + + # return dictionary mapping observer fqns to desired info + input_info: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # get pre observer for the module + pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + input_info[fqn] = { + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val, + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val, + self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val), + self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val), + } + + return input_info + + def _extract_weight_info(self, model: GraphModule) -> dict[str, dict]: + r""" + Takes in a calibrated GraphModule and then finds the relevant observers. + It then extracts the weight information for each layer an observer is attached to. + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping module fqns (str) to a dict with keys: + "per_channel_max" : maps to the per_channel max values + "per_channel_min" : maps to the per_channel min values + "global_max" : maps to the global max recorded + "global_min" : maps to the global min recorded + """ + # return dictionary mapping observer fqns to desired info + weight_info: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # we don't need actual observer, just the module weights + # calculate min and max vals + device = module.weight.device + min_val: torch.Tensor = torch.tensor([float("inf")], device=device) + max_val: torch.Tensor = torch.tensor([float("-inf")], device=device) + x_copy = module.weight + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + weight_info[fqn] = { + self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val, + self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val, + self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val), + self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val), + } + + return weight_info + + def _calculate_range_ratio( + self, info_dict: dict, info_str: str, module_fqn: str + ) -> torch.Tensor: + r""" + Takes in an info dict and calculates the s_c matrix. + + Args: + info_dict (dict): A dictionary of either input or weight range info + info_str (str): A str describing whether currently looking at weight or input info + Either "weight" or "input" + module_fqn (str): The fqn of the module we are looking at + + Returns a tensor of values, where each value is the s_c stat for a different channel + """ + # calculate the ratios of the info + # get the prefix str + prefix_str = ( + self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX + ) + + per_channel_range = ( + info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] + - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY] + ) + global_range = ( + info_dict[prefix_str + self.GLOBAL_MAX_KEY] + - info_dict[prefix_str + self.GLOBAL_MIN_KEY] + ) + + if global_range == 0: + range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information." + raise ValueError( + f"The range of the {info_str} data for module {module_fqn} is 0, " + f"which means you have a constant value channel. {range_zero_explanation}" + ) + + ratio = per_channel_range / global_range + + return ratio + + def _generate_comparison_values( + self, input_info: dict, weight_info: dict + ) -> dict[str, torch.Tensor]: + r""" + Takes in the information on the min and max values of the inputs and weights and: + Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I) + + Args: + input_info (dict): A dict mapping each observer to input range information + weight_info (dict): A dict mapping each observer to weight range information + + Returns a dict mapping relevant observer fqns (str) to a 1-D tensor. + Each value is a different s_c value for a different channel + """ + # create return dictionary for each observer + module_fqn_to_channel: dict[str, torch.Tensor] = {} + + # for each module (both passed in dicts should have same keys) + for module_fqn in input_info: + # raise error if not in weight info + if module_fqn not in weight_info: + raise KeyError( + f"Unable to find weight range stats for module {module_fqn}" + ) + + # calculate the ratios of the weight info and input info + weight_ratio = self._calculate_range_ratio( + weight_info[module_fqn], self.WEIGHT_STR, module_fqn + ) + input_ratio = self._calculate_range_ratio( + input_info[module_fqn], self.INPUT_STR, module_fqn + ) + + # if mismatched size, because of grouping, we want to replicate weight enough times + weight_channels = len(weight_ratio) + input_channels = len(input_ratio) + if weight_channels != input_channels: + # we try to replicate + if input_channels % weight_channels != 0: + raise AssertionError( + "input channels should be divisible by weight channels." + ) + # get replication factor + rep_factor: int = input_channels // weight_channels + + # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n + weight_ratio = weight_ratio.repeat(rep_factor) + + # calculate the s metric per channel + s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio) + module_fqn_to_channel[module_fqn] = s + + # return compiled observer ratios + return module_fqn_to_channel + + def _generate_dict_info( + self, input_info: dict, weight_info: dict, comp_stats: dict + ) -> dict[str, dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + input_info (dict): A dict mapping each module to input range information + weight_info (dict): A dict mapping each module to weight range information + comp_stats (dict): A dict mapping each module to its corresponding comp stat + + Returns a dictionary mapping each module with relevant ModelReportObservers around them to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + # store modules input weight equalization info + input_weight_equalization_info: dict[str, dict] = {} + + # for each module we add separate set of suggestions + for module_fqn in input_info: + # get relevant info for this module + mod_input_info: dict = input_info[module_fqn] + mod_weight_info: dict = weight_info[module_fqn] + mod_comp_stat: dict = comp_stats[module_fqn] + + # decide if each channel should have input weight equalization or not + channel_rec_vals: list = [] + + for val in mod_comp_stat: + float_rep: float = val.item() + + # decide if recommending input weight equalization + recommended: bool = ( + float_rep >= self.ratio_threshold + and float_rep <= 1 / self.ratio_threshold + ) + channel_rec_vals.append(recommended) + + # build the return dict input + # also unpack input and weight dicts into it + input_weight_equalization_info[module_fqn] = { + self.RECOMMENDED_KEY: channel_rec_vals, + self.COMP_METRIC_KEY: mod_comp_stat, + self.THRESHOLD_KEY: self.ratio_threshold, + self.CHANNEL_KEY: self.ch_axis, + **mod_input_info, + **mod_weight_info, + } + + # return our compiled info for each module + return input_weight_equalization_info + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. + + Takes advantage of the ModelReport Observer which records per channel information of input range + It then uses the passed in weight info inconjunction to compute the desired ratio + Finally, it gives suggestions based on this information for each module of interest + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether input weight equalization is recommended for certain modules + Dictionary mapping modules of interest to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + + # find the range of inputs + input_values: dict[str, dict] = self._extract_input_info(model) + + # find the range of weights + weight_values: dict[str, dict] = self._extract_weight_info(model) + + # calculate per_channel comparison statistic s_c + comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values( + input_values, weight_values + ) + + # generate the return dictionary + input_weight_equalization_info: dict[str, dict] = self._generate_dict_info( + input_values, weight_values, comp_stats + ) + + # now we can generate report based on this information + input_weight_string = "Input-Weight Equalization suggestions: \n" + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = ( + "\tWe suggest {} input weight equalization because {}\n" + ) + use_str = "to use" + no_use_str = "to not use" + input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error." + input_weight_non_benefit_reasoning = ( + "{}/{} channels benefitting from input-weight equalization being applied." + ) + input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}" + + # added module check + added_module: bool = False + + # compile the suggestion string + for module_fqn in input_weight_equalization_info: + # we added at least 1 module + added_module = True + # add the module level description + input_weight_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + + mod_info: dict[str, Any] = input_weight_equalization_info[module_fqn] + + # gather info on how many channels would benefit from input weight and + recommendation_per_channel: torch.Tensor = mod_info[self.RECOMMENDED_KEY] + num_recs = sum(recommendation_per_channel) + + if ( + num_recs / len(recommendation_per_channel) + >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO + ): + input_benefit_formatted = input_weight_benefit_str.format( + num_recs, len(recommendation_per_channel) + ) + channel_str = channel_suggestion_str.format( + use_str, input_benefit_formatted + ) + input_weight_string += channel_str + else: + non_benefit_reason_formatted = ( + input_weight_non_benefit_reasoning.format( + num_recs, len(recommendation_per_channel) + ) + ) + non_benefit_str = input_weight_non_benefit_str.format( + non_benefit_reason_formatted + ) + channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str) + input_weight_string += channel_str + + # if no modules looked at, amend return string + if not added_module: + input_weight_string += ( + "No applicable layers for suggestions. Only linear and conv valid.\n" + ) + + # return a tuple with the string explanation and the compiled dict info + return (input_weight_string, input_weight_equalization_info) + + +class OutlierDetector(DetectorBase): + r""" + Determines whether there are significant outliers in activation data around a certain layer. + + This is ideally used in conjunction with information on stationary vs. non-stationary distribution: + If the data is stationary, and there are significant outliers, then we want to flag them + We want to do this on a per channel basis for detecting outliers + + Determines whether activation data is flagged as outlier based on if data is stationary and: + p_r = avg(100th percentile / "reference_percentile"th percentile) + where: + p_r is average percentile ratio across all batches in the epoch + reference_percentile is a percentile values between 0 and 100 exclusive + + if p_r is above some threshold, then we consider the activations to have significant outliers + + Args: + ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations + Should be >= 1 + Default: 3.5 + reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile + Should be between 0 and 1 + Default: 0.975 + fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier + If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user + regardless of whether we detected outliers or not in channel to take a closer look at channel results + Should be between 0 and 1 + Default: 0.95 + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations + The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold + If it is significantly greater, then we consider it an outlier + This threshold was calculated based on the ratio of the percentiles in a normal distribution + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile + Should be between 0 and 1 + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this + Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine outliers + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + # names for the pre observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # pre activation prefix + INPUT_ACTIVATION_PREFIX = "input_activation_" + + # names for dict keys + OUTLIER_KEY = "outliers_detected" + NUM_BATCHES_KEY = "outlier_detection_batches_used" + IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches" + COMP_METRIC_KEY = "outlier_detection_percentile_ratios" + RATIO_THRES_KEY = "outlier_detection_ratio_threshold" + REF_PERCENTILE_KEY = "outlier_detection_reference_percentile" + CHANNEL_AXIS_KEY = "outlier_detection_channel_axis" + MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max" + CONSTANT_COUNTS_KEY = "constant_batch_counts" + + def __init__( + self, + ratio_threshold: float = 3.5, + reference_percentile: float = 0.975, + fraction_batches_used_threshold: float = 0.95, + ch_axis: int = 1, + ): + # initialize the variables of interest + self.ratio_threshold = ratio_threshold + + # make sure passed in percentile is valid + if reference_percentile < 0 or reference_percentile > 1: + raise AssertionError("reference_percentile must be between 0 and 1") + if not ( + fraction_batches_used_threshold >= 0 + and fraction_batches_used_threshold <= 1 + ): + raise AssertionError( + "fraction_batches_used_threshold must be between 0 and 1" + ) + self.reference_percentile = reference_percentile + self.fraction_batches_used_threshold = fraction_batches_used_threshold + self.ch_axis = ch_axis + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "outlier_detector" + + def _supports_insertion(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for observers insertion + + Any module that doesn't have children and isn't an observer itself is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + # case for insertion of module + # check if the module has any children and isn't observer + num_children = len(list(module.children())) + return num_children == 0 and not _is_activation_post_process(module) + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # currently doesn't do anything for outlier detector + return {} + + def _supports_report_gen(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for report generation + + Any module that has a model report pre-observer is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r"""Determines where observers need to be inserted for the Outlier Detector. + + For this detector, we want to place observers in front of supported layers. + + Currently inserts observers for: + all layers that do not have children (leaf level layers) + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._supports_insertion(module): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr( + ch_axis=self.ch_axis, comp_percentile=self.reference_percentile + ), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def _calculate_outlier_info( + self, + percentile_ratios: torch.Tensor, + counted_batches: torch.Tensor, + total_batches: int, + ) -> dict[str, list[bool]]: + r""" + Gives info on whether the percentile ratios calculated would be considered outliers + Also gives information on whether the collected data is statistically significant to make this claim + + Args: + percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer + counted_batches (torch.Tensor): The number of batches used for average calculation per tensor + total_batches (int): The total number of batches that passed through observer in this epoch + + Returns a dictionary mapping: + "outliers_detected" : list of bools per channel that are true if it is considered an outlier + "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold: + where o_r = counted_batches / total_batches + """ + outlier_dict: dict[str, list[bool]] = { + self.OUTLIER_KEY: [], + self.IS_SUFFICIENT_BATCHES_KEY: [], + } + + # get both as flattened lists for easy mapping + ratios_list: list = percentile_ratios.tolist() + num_batches_list: list = counted_batches.tolist() + + # calculate whether channels were statistically significant + significant_size = [ + batch_size / total_batches >= self.fraction_batches_used_threshold + for batch_size in num_batches_list + ] + outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size + + # calculate for each channel whether it's an outlier or not based on ratio + outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list] + outlier_dict[self.OUTLIER_KEY] = outlier_detected + + # return the dictionary with the two lists + return outlier_dict + + def _generate_info_dict(self, model: GraphModule) -> dict[str, dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relevant module fqns to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # return dictionary mapping observer fqns to desired info + info_dict: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._supports_report_gen(module): + # get pre observer for the module + pre_obs: ModelReportObserver = getattr( + module, self.DEFAULT_PRE_OBSERVER_NAME + ) + + # get the number of batches and calculated ratio thresholds + num_batches: torch.Tensor = pre_obs.percentile_batches_tracked + average_ratios: torch.Tensor = pre_obs.average_percentile_ratio + channel_batch_cnts: torch.Tensor = pre_obs.constant_channels + total_batches: int = pre_obs.num_batches_tracked + + # also get the max values + max_vals: torch.Tensor = pre_obs.max_val + + # we have to specifically modify how we are recording negative ratio for pre-relu layers + for index, ratio_val in enumerate(average_ratios): + # check if we have a negative ratio + # a ratio might be negative if we have a situation where the 100th percentile is + # > 0 while the nth percentile is < 0, in which case this would not be detected + # as an outlier. Since we care more about magnitude, we make it positive. + if ratio_val.item() < 0: + # first make it positive + average_ratios[index] = -ratio_val + + if ratio_val.item() < 1: + # if it's less than 1 we have the flip it as well + average_ratios[index] = 1 / ratio_val + + outlier_calcs = self._calculate_outlier_info( + average_ratios, num_batches, total_batches + ) + + # calculate whether ratios were outliers + info_dict[fqn] = { + self.CHANNEL_AXIS_KEY: self.ch_axis, + self.REF_PERCENTILE_KEY: self.reference_percentile, + self.RATIO_THRES_KEY: self.ratio_threshold, + self.COMP_METRIC_KEY: average_ratios, + self.NUM_BATCHES_KEY: num_batches, + self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY], + self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[ + self.IS_SUFFICIENT_BATCHES_KEY + ], + self.CONSTANT_COUNTS_KEY: channel_batch_cnts, + self.MAX_VALS_KEY: max_vals, + } + + return info_dict + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. + + Takes advantage of the ModelReport Observer which records the relevant percentile information + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether there are outliers in the activations around certain modules + Dictionary mapping modules of interest to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # generate the information dictionary of outlier information + info_dict = self._generate_info_dict(model) + + # now we can generate report based on this information + outlier_string = "Outlier detection report: \n" + + # added module check + added_module: bool = False + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n" + channel_max_value_str = "a max value across all batches of {}" + note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results." + note_distribution = "stationary distributions" + note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary" + + # suggestion for constant batch check since that can make it no outliers + constant_str = "\tFor channel {}, we found {} constant value batches. {}\n" + constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why." + + # compile the suggestion string + for module_fqn in info_dict: + # get module specific info + mod_info: dict[str, Any] = info_dict[module_fqn] + # check to see if we already added high level model desc + added_model_desc = False + # look at each individual channel and add a suggestion + for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]): + if outlier_detected: + # we found at least 1 outlier + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + added_model_desc = True + + # we mark that we found at least one outlier + added_module = True + max_value_found_str = channel_max_value_str.format( + mod_info[self.MAX_VALS_KEY][index] + ) + channel_str = channel_suggestion_str.format( + index, max_value_found_str + ) + outlier_string += channel_str + + # also check if we found constant batch + if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0: + # make sure we add a module level highlight. + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + added_model_desc = True + + constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][ + index + ] + formatted_str = constant_str.format( + index, constant_values_for_channel, constant_suggestion + ) + outlier_string += formatted_str + # we also added at least one thing to description + added_module = True + + # if found outlier, give suggestion, else give default response + if added_module: + # compose the note string + note_composed = note_string.format(note_distribution, note_rec) + outlier_string += note_composed + else: + outlier_string += "There were no outliers found in the activations.\n" + + return (outlier_string, info_dict) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffbff88dd2d80dc237ae779eddd6fad5d26daee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py @@ -0,0 +1,666 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict +from collections.abc import Callable +from typing import Any + +import torch +from torch.ao.quantization.fx._equalize import EqualizationQConfig +from torch.ao.quantization.fx._model_report.detector import ( + DETECTOR_IS_POST_OBS_KEY, + DETECTOR_OBS_ARGS_KEY, + DETECTOR_OBS_TO_INSERT_KEY, + DETECTOR_TARGET_NODE_KEY, + DetectorBase, + DetectorQConfigInfo, +) +from torch.ao.quantization.fx._model_report.model_report_visualizer import ( + ModelReportVisualizer, +) +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping + + +class ModelReport: + r""" + The ModelReport class aims to provide users an easy way to diagnose issues that they run into + with their models. The class works with all traceable GraphModules to help diagnose issues, + though the requirements on the type of model more-so depends on the specific report the user + is trying to generate. With respect to the reports, the ModelReport class is initialized with + a set of Detector classes, each of which generate reports on quantization configuration + issues a use might have. + + Currently supports generating reports on: + - Suggestions for per-channel vs. per-tensor quantization (nn.Module) + - Suggestions for dynamic vs static quantization for linear layers (Graph Modules) + - Suggestions for input-weight equalization for linear and conv layers (Graph Modules) + - Suggestions for outlier detection for all layers (Graph Modules) + + The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver) + where needed for each detector to gather the information it needs, and then after calibration, the ModelReport + class compiles the report generated by each Detector class into a single report to return to the user. It also + has the capability to remove all the observers it inserted as well. + + * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule + + * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class + Make sure that these are all unique types of detectors [do not have more than 1 of the same class] + + * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors. + This set is generated by calling the get_detector_name() of each detector + + * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest + The purpose of this is to keep track of what observers were inserted for each detector, so that they + can be removed at the end if desired + + * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not + This is to ensure we only insert observers once with the ModelReport instance + + * :attr:`_removed_observers` A boolean to track if we have removed observers already + The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport + instance. This also allows the functionality where we can generate the report multiple times + as long as we haven't removed the observers yet. + + Note: + This class was initially designed to work with the Fx Graph Mode workflow in mind. However, + full functionality is available as long as there is a traceable GraphModule that is being used. + One method to get a traceable GraphModule without going through the Fx workflow is to use + the QuantizationTracer class. + + General Flow for Fx workflow: + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration to add relevant observers + 4.) Calibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + Optional + 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance + 7.) To help in parsing report information and debugging, view report info as a: + - Table + - Histogram + - Line plot + 8.) Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions + + Example (with QuantizationTracer): + >>> # xdoctest: +SKIP + >>> # get the necessary qconfig + >>> config = PrepareCustomConfig() + >>> skipped_module_names, skipped_module_classes = ( + ... get_skipped_module_name_and_classes(config, False) + ... ) + + >>> # initialize our model and get GraphModule + >>> model = SomeModel() + >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) + >>> graph_module = GraphModule(model, tracer.trace(model)) + + >>> # get our set of detectors and ModelReport instance + >>> detector_set = set( + ... [ + ... DynamicStaticDetector(tolerance=0.5), + ... InputWeightEqualizationDetector(ratio_threshold=0.7), + ... ] + ... ) + >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set) + + >>> # now we insert the observers and calibrate the model + >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration() + >>> for i in range(num_callibration_batches): + >>> example_input = get_callibration_input() + >>> tracer_model_with_observers(example_input) + + >>> # finally we generate the reports and optionally remove the observers we inserted + >>> reports = tracer_reporter.generate_model_report( + ... remove_inserted_observers=True + ... ) + + >>> # Optional: we can generate the qconfig mapping based on the suggestions + >>> qconfigs = model_report.generate_qconfig_mapping() + + >>> # Optional: we can generate the equalization mapping based on the suggestions + >>> qconfigs = model_report.generate_equalization_mapping() + + >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired + >>> model_report_visualizer = tracer_reporter.generate_visualizer() + + """ + + def __init__(self, model: GraphModule, desired_report_detectors: set[DetectorBase]): + if len(desired_report_detectors) == 0: + raise ValueError("Should include at least 1 desired report") + + # keep track of the model we wish to generate report for + self._model: GraphModule = model + + # keep the reports private so they can't be modified + self._desired_report_detectors = desired_report_detectors + self._desired_detector_names = { + detector.get_detector_name() for detector in desired_report_detectors + } + + # keep a mapping of desired reports to observers of interest + # this is to get the readings, and to remove them, can create a large set + # this set can then be used to traverse the graph and remove added observers + self._detector_name_to_observer_fqns: dict[str, set[str]] = {} + + # initialize each report to have empty set of observers of interest + for desired_report in self._desired_detector_names: + self._detector_name_to_observer_fqns[desired_report] = set() + + # flags to ensure that we can only prepare and remove observers once + self._prepared_flag = False + self._removed_observers = False + + # store the reports that we generated for visualization purposes + # initially empty since no reports generated + self._generated_reports: dict[str, dict] = {} + + def get_desired_reports_names(self) -> set[str]: + """Returns a copy of the desired reports for viewing""" + return self._desired_detector_names.copy() + + def get_observers_of_interest(self) -> dict[str, set[str]]: + """Returns a copy of the observers of interest for viewing""" + return self._detector_name_to_observer_fqns.copy() + + def prepare_detailed_calibration(self) -> GraphModule: + r""" + Takes in a graph model and inserts the following observers: + - ModelReportObserver + + Each observer is inserted based on the desired_reports into the relevant locations + + Right now, each report in self._desired_detector_names has independent insertions + However, if a module already has a Observer of the same type, the insertion will not occur + This is because all of the same type of Observer collect same information, so redundant + + Returns the same GraphModule with the observers inserted + """ + + # if already prepared once, cannot prepare again + if self._prepared_flag: + raise ValueError( + "Already ran preparing detailed calibration. Run the report generation next after calibration." + ) + + # loop through each detector, find where placements should be, and keep track + insert_observers_fqns: dict[str, Any] = {} + + for detector in self._desired_report_detectors: + # determine observer points for each detector + obs_fqn_to_info = detector.determine_observer_insert_points(self._model) + # map each insert point to the observer to use + insert_observers_fqns.update(obs_fqn_to_info) + # update the set of observers this report cares about + self._detector_name_to_observer_fqns[detector.get_detector_name()] = set( + obs_fqn_to_info.keys() + ) + + # now insert all the observers at their desired locations + for observer_fqn in insert_observers_fqns: + target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY] + insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY] + insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY] + observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY] + self._insert_observer_around_module( + observer_fqn, target_node, insert_obs, observer_args, insert_post + ) + + self._prepared_flag = True + + return self._model + + def _insert_observer_around_module( + self, + obs_fqn: str, + target_node: torch.fx.node.Node, + obs_to_insert: ObserverBase, + observer_args: tuple, + insert_post: bool, + ): + r""" + Helper function that inserts the observer into both the graph structure and the module of the model + + Args + node_fqn (str): The fully qualified name of the observer we want to insert + target_node (torch.fx.node.Node): The node in model we are inserting observers around + obs_to_insert (ObserverBase): The observer we are inserting around target_node + observer_args (Tuple): The arguments we want to pass into the observer + insert_post (bool): whether this is meant to be a post observer for this node + """ + # if we are inserting post, then our target node is the next node + if insert_post: + target_node = target_node.next + + with self._model.graph.inserting_before(target_node): + self._model.add_submodule(obs_fqn, obs_to_insert) + self._model.graph.create_node( + op="call_module", target=obs_fqn, args=observer_args + ) + + # recompile model after inserts are made + self._model.recompile() + + def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node: + r""" + Takes in a node fqn and returns the node based on the fqn + + Args + node_fqn (str): The fully qualified name of the node we want to find in model + + Returns the Node object of the given node_fqn otherwise returns None + """ + node_to_return = None + for node in self._model.graph.nodes: + # if the target matches the fqn, it's the node we are looking for + if node.target == node_fqn: + node_to_return = node + break + + if node_to_return is None: + raise ValueError("The node_fqn is was not found within the module.") + + # assert for MyPy + if not isinstance(node_to_return, torch.fx.node.Node): + raise AssertionError("node_to_return must be a torch.fx.node.Node") + + return node_to_return + + def generate_model_report( + self, remove_inserted_observers: bool + ) -> dict[str, tuple[str, dict]]: + r""" + Generates all the requested reports. + + Note: + You should have calibrated the model with relevant data before calling this + + The reports generated are specified by the desired_reports specified in desired_reports + + Can optionally remove all the observers inserted by the ModelReport instance + + Args: + remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance + + Returns a mapping of each desired report name to a tuple with: + The textual summary of that report information + A dictionary containing relevant statistics or information for that report + + Note: + Throws exception if we try to generate report on model we already removed observers from + Throws exception if we try to generate report without preparing for calibration + """ + # if we haven't prepped model for calibration, then we shouldn't generate report yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for calibration" + ) + + # if we already removed the observers, we cannot generate report + if self._removed_observers: + raise Exception( # noqa: TRY002 + "Cannot generate report on model you already removed observers from" + ) + + # keep track of all the reports of interest and their outputs + reports_of_interest = {} + + for detector in self._desired_report_detectors: + # generate the individual report for the detector + report_output = detector.generate_detector_report(self._model) + reports_of_interest[detector.get_detector_name()] = report_output + + # if user wishes to remove inserted observers, go ahead and remove + if remove_inserted_observers: + self._removed_observers = True + # get the set of all Observers inserted by this instance of ModelReport + all_observers_of_interest: set[str] = set() + for desired_report in self._detector_name_to_observer_fqns: + observers_of_interest = self._detector_name_to_observer_fqns[ + desired_report + ] + all_observers_of_interest.update(observers_of_interest) + + # go through all_observers_of_interest and remove them from the graph and model + for observer_fqn in all_observers_of_interest: + # remove the observer from the model + self._model.delete_submodule(observer_fqn) + + # remove the observer from the graph structure + node_obj = self._get_node_from_fqn(observer_fqn) + + if node_obj: + self._model.graph.erase_node(node_obj) + else: + raise ValueError("Node no longer exists in GraphModule structure") + + # remember to recompile the model + self._model.recompile() + + # save the generated reports for visualization purposes + saved_reports: dict[str, dict] = { + report_name: report_tuple[1] + for report_name, report_tuple in reports_of_interest.items() + } + + self._generated_reports = saved_reports + + # return the reports of interest + return reports_of_interest + + def _is_same_info_for_same_key(self, info_dict_a: dict, info_dict_b: dict) -> bool: + r""" + Takes in two dictionaries and ensures that any common keys between the two have the same + values. + + Args: + info_dict_a (Dict): First dictionary we wish to compare + info_dict_b (Dict): Second dictionary we wish to compare + + Returns True if all shared keys have same values, false otherwise + """ + # get the set of keys for both + dict_a_keys: set = set(info_dict_a.keys()) + dict_b_keys: set = set(info_dict_b.keys()) + + # get the insersection keys and check if same value for both dicts + intersecting_keys: set = dict_a_keys.intersection(dict_b_keys) + + for key in intersecting_keys: + dict_a_val = info_dict_a[key] + dict_b_val = info_dict_b[key] + + # if it's a tensor we have to handle separately + if type(dict_a_val) is torch.Tensor: + # if dict_b_val not tensor, automatically false + if ( + type(dict_b_val) is not torch.Tensor + or sum(dict_a_val != dict_b_val) != 0 + ): + return False + else: + # for non-tensor vals + if dict_a_val != dict_b_val: + return False + + # if no non matching shared keys found, return true + return True + + def _reformat_reports_for_visualizer(self) -> OrderedDict: + r""" + Takes the generated reports and reformats them into the format that is desired by the + ModelReportVisualizer + + Returns an OrderedDict mapping module_fqns to their features + """ + # we want to reorder and reformat the information so it is ordered in terms of order + # found in the model + + # first create new dict with all modules as keys and features under respective module + module_fqns_to_features: dict[str, dict] = {} + + for report_name in self._generated_reports: + # get mod -> feature dict and go through + module_info = self._generated_reports[report_name] + + for module_fqn in module_info: + # check if already in our accumulation dict + if module_fqn in module_fqns_to_features: + # we merge all the features together + new_info: dict = module_info[module_fqn] + present_info: dict = module_fqns_to_features[module_fqn] + + # merge them together into the new unioned dict + # same features keys -> same info, so okay if override + + # do safety check to make sure shared keys have same info + if self._is_same_info_for_same_key(new_info, present_info): + module_fqns_to_features[module_fqn] = { + **new_info, + **present_info, + } + else: + error_str = "You have the same key with different values across detectors. " + error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors." + raise ValueError(error_str) + else: + # we just set it + module_fqns_to_features[module_fqn] = module_info[module_fqn] + + # our ordered dict so that modules can be ordered in order of how they appear in model + features_by_module: OrderedDict[str, dict] = OrderedDict() + + # we loop through modules in graph in order + for fqn, _module in self._model.named_modules(): + # find that fqn in fqns_to_features + if fqn in module_fqns_to_features: + # add it to our ordered dict + features_by_module[fqn] = module_fqns_to_features[fqn] + + # return the ordered dict of info we created + return features_by_module + + def generate_visualizer(self) -> ModelReportVisualizer: + r""" + Generates a ModelReportVisualizer instance using the reports generated + by the generate_model_report() method. + + Returns the generated ModelReportVisualizer instance initialized + + Note: + Throws exception if attempt to get visualizers without generating report + """ + # check if user has generated reports at least once + if len(self._generated_reports) == 0: + raise Exception( # noqa: TRY002 + "Unable to generate visualizers without first generating reports" + ) + + # get the ordered dict mapping modules to their full set of collected features / stats + module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer() + + # create and return ModelReportVisualizer instance + visualizer: ModelReportVisualizer = ModelReportVisualizer( + module_fqns_to_features + ) + + return visualizer + + def _generate_qconfig_mapping_helper( + self, + detector_qconfig_info_combined: dict[str, DetectorQConfigInfo], + generation_function: Callable, + ) -> QConfigMapping: + r""" + This helper takes in the compiled detector qconfig info that + has been compiled together and merges it into a QConfigMapping + """ + # keep track of the qconfigmapping + qconfig_mapping = QConfigMapping() + + # loop through each module / fqn and attempt to create QConfigMapping + for fqn, module in self._model.named_modules(): + # if we have a qconfig info for this module + if fqn in detector_qconfig_info_combined: + qconfig_info_compiled = detector_qconfig_info_combined[fqn] + + # now generate the qconfig and add it to the mapping + generated_qconfig = generation_function(qconfig_info_compiled, module) + + # add to our config + qconfig_mapping.set_module_name(fqn, generated_qconfig) + + # return compiled mapping + return qconfig_mapping + + def _update_detector_quantizaiton_qconfig_info( + self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo + ): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + combined_info.is_activation_dynamic = ( + combined_info.is_activation_dynamic or new_info.is_activation_dynamic + ) + combined_info.is_weight_per_channel = ( + combined_info.is_weight_per_channel or new_info.is_weight_per_channel + ) + + def _update_detector_equalization_qconfig_info( + self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo + ): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + is_equalization_recommended = ( + combined_info.is_equalization_recommended + or new_info.is_equalization_recommended + ) + combined_info.is_equalization_recommended = is_equalization_recommended + + def _generate_module_fqn_to_detector_info_mapping( + self, update_qconfig_info_function: Callable + ) -> dict[str, DetectorQConfigInfo]: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API. The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Args: + update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo + and updates the one that is being compiled + + Returns a Dict mapping module_fqns to DetectorQConfigInfo objects + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for calibration + """ + # if we haven't prepped model for calibration, then we shouldn't generate mapping yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for calibration" + ) + + # if we already removed the observers, we cannot mapping + if self._removed_observers: + raise Exception( # noqa: TRY002 + "Cannot generate report on model you already removed observers from" + ) + + # keep track of qconfig info for each module across detectors + detector_qconfig_info_combined: dict[str, DetectorQConfigInfo] = {} + + for detector in self._desired_report_detectors: + # get the info from the detector + detector_info: dict[str, DetectorQConfigInfo] = detector.get_qconfig_info( + self._model + ) + + # we go through the modules + for module_fqn in detector_info: + # see if we already have info on it + if module_fqn in detector_qconfig_info_combined: + # we combine the current options with what is there + current_options = detector_qconfig_info_combined[module_fqn] + detector_options = detector_info[module_fqn] + + update_qconfig_info_function(current_options, detector_options) + else: + # we just use this for now + detector_qconfig_info_combined[module_fqn] = detector_info[ + module_fqn + ] + + return detector_qconfig_info_combined + + def generate_qconfig_mapping(self) -> QConfigMapping: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API. The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Returns a QConfigMapping for the quantization configuration + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for calibration + """ + # get the mapping info + detector_qconfig_info_combined = ( + self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_quantizaiton_qconfig_info + ) + ) + + # we will do a bit of processing and remove fqns that don't have input weight recommended + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, self._quantization_config_generator + ) + + # return the generated mapping + return mapping + + def _quantization_config_generator( + self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module + ) -> QConfig: + r""" + Returns the quantization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_quantization_qconfig(module) + + def _equalization_config_generator( + self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module + ) -> EqualizationQConfig: + r""" + We ignore the module argument here, and only focus on thedetector_qconfig_info + + Returns the equalization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_equalization_qconfig() + + def generate_equalization_mapping(self) -> QConfigMapping: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API for equalization. The generated mapping encompasses all the + different types of feedback from the input-weight equalization detector. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Returns a QConfigMapping for the equalization configuration + """ + # get the mapping info + detector_qconfig_info_combined = ( + self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_equalization_qconfig_info + ) + ) + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, self._equalization_config_generator + ) + + # return the generated mapping + return mapping diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py new file mode 100644 index 0000000000000000000000000000000000000000..a809dc60838e574e0bd484ee9698e9d1a0a5ee47 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import torch +from torch.ao.quantization.observer import ObserverBase + + +class ModelReportObserver(ObserverBase): + r"""This observer is used to record additional information regarding keeping track + of S = average_batch_activation_range/epoch_activation_range. + + The purpose of this information is to prepare a report to present to users on whether + Dynamic or Static Quantization is more appropriate for their model given the general + distributions of their data. + + Args: + ch_axis (int, optional): The channel axis for which the range and outlier stats are computed + Default: 1 + comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers + Should be between 0 and 1 exclusive + Default: 0.9 + + * :attr:`num_batches_tracked` specifies number of batches passed through the observer + + * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through + + * :attr:`epoch_activation_min` defines the minimum value passed through the observer + + * :attr:`epoch_activation_max` defines the maximum value passed through the observer + + * :attr:`ch_axis` defines the channel being used to compute per channel min max stats + + * :attr:`min_val` defines the per channel minimum values passed through + + * :attr:`max_val` defines the per channel maximum values passed through + + * :attr:`comp_percentile` defines comparison percentile to find outliers + + * :attr:`average_percentile_ratio` defines the per channel average percentile ratios + + * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel + + * :attr:`constant_channels` defines the number of batches that aren't constant channels per channel + + Note: this tool is meant for FX Graph Mode Quantization + """ + + epoch_activation_min: torch.Tensor + epoch_activation_max: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + comp_percentile: torch.Tensor + average_percentile_ratio: torch.Tensor + percentile_batches_tracked: torch.Tensor + constant_channels: torch.Tensor + + def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9): + super().__init__(torch.qint8) + self.num_batches_tracked = 0 + + # keep track of the min and mix of the range for average batch and epoch as a whole + self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0)) + self.register_buffer("epoch_activation_min", torch.tensor(float("inf"))) + self.register_buffer("epoch_activation_max", torch.tensor(float("-inf"))) + + # keep track of per channel min max information using the given channel + self.ch_axis: int = ch_axis + self.register_buffer("min_val", torch.tensor([])) + self.register_buffer("max_val", torch.tensor([])) + + # keep track of percentile ratio information per channel + self.register_buffer("comp_percentile", torch.tensor([comp_percentile])) + self.register_buffer("average_percentile_ratio", torch.tensor([])) + self.register_buffer("percentile_batches_tracked", torch.tensor([])) + self.register_buffer("constant_channels", torch.tensor([])) + + def forward(self, x): + x_copy = x.detach() # avoid keeping autograd tape + x_copy = x_copy.to(self.epoch_activation_min.dtype) + + x_copy = self._calculate_range_stats(x_copy) + x_copy = self._calculate_min_max_stats(x_copy) + x_copy = self._calculate_percentile_stats(x_copy) + + # return the passed in the value + return x + + def _calculate_range_stats(self, x_copy): + r"""Calculates and stores range stats with forward values. + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the min, max values of the data + min_val_cur, max_val_cur = torch.aminmax(x_copy) + + # calculate new epoch range values + epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur) + epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur) + + self.epoch_activation_min.copy_(epoch_min_val) + self.epoch_activation_max.copy_(epoch_max_val) + + # calculate the average batch activation range + current_batch_range = max_val_cur - min_val_cur + new_range = ( + self.average_batch_activation_range * self.num_batches_tracked + + current_batch_range + ) / (self.num_batches_tracked + 1) + + self.average_batch_activation_range = new_range + self.num_batches_tracked += 1 # new batch was processed + + return x_copy + + def _calculate_min_max_stats(self, x_copy): + r"""Calculates and stores the per_channel min, max stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the current min and max vals + min_val = self.min_val + max_val = self.max_val + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + return x_copy + + def _calculate_percentile_stats(self, x_copy): + r"""Calculates and stores the per_channel percentile stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the dimension of the copy + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + y = y.to(dtype=self.min_val.dtype, device="cpu") + + # find the percentile values along the axis + # we want both 100th percentile and comp_percentile + # we also want to find 0th quartile to see if we have constant channel + quantiles_list = [0, self.comp_percentile, 1.00] + quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype) + + # find the quantiles + desired_quantiles = torch.quantile( + y, quantiles_to_find, dim=self.ch_axis, interpolation="lower" + ) + zero_quantile = desired_quantiles[0] + comp_quantile = desired_quantiles[1] + hundreth_quartile = desired_quantiles[2] + + # if any of the channels have 0s, we ignore that channel for this calculation + any_non_zero_quantile_value: torch.Tensor = ( + comp_quantile != torch.tensor([0]) + ) | (hundreth_quartile != torch.tensor([0])) + any_non_zero_quantile_value = ( + any_non_zero_quantile_value.int() + ) # transform boolean values to int values + + # we also check if we have a constant channel + any_constant_channels: torch.Tensor = ( + hundreth_quartile - zero_quantile + ) == torch.tensor([0]) + any_constant_channels = ( + any_constant_channels.int() + ) # transform boolean values to int values + + # possibilities to get nan as an answer + # will ignore any of these three cases with 0s and just not deal with them for now + # case (1) 0 in numerator: issue if 0 is largest, all negative, and rest are really negative + # case (2) 0 in denominator: is possible unless case 3, we just ignore + # case (3) 0 in both: not outlier, channel just kinda useless, ignore + + # get the ratio and get rid of nan values + quantile_ratios = hundreth_quartile / comp_quantile + quantile_ratios = torch.nan_to_num(quantile_ratios) + # update averages, remembering to only update if didn't have zeros + ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios + + # if num_batches and average_ratio are not initialized, we want to initialize them + if ( + self.percentile_batches_tracked.shape[0] == 0 + or self.average_percentile_ratio.shape[0] == 0 + ): + self.percentile_batches_tracked = torch.zeros_like( + any_non_zero_quantile_value + ) + self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero) + + # also initialize the constant channel var if that is not initialized separately + if self.constant_channels.shape[0] == 0: + self.constant_channels = torch.zeros_like(any_constant_channels) + + # get current num batches and average ratio + num_batches = self.percentile_batches_tracked + average_ratio = self.average_percentile_ratio + + # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches + new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value + new_ratios: torch.Tensor = ( + (average_ratio * num_batches) + ratio_if_not_zero + ) / new_number_of_batches + new_ratios = torch.nan_to_num(new_ratios) + + # update the number of non-constant channels + new_constant_count: torch.Tensor = ( + self.constant_channels + any_constant_channels + ) + + # update the values locally + self.percentile_batches_tracked.copy_(new_number_of_batches) + self.average_percentile_ratio.copy_(new_ratios) + self.constant_channels.copy_(new_constant_count) + + return x_copy + + @torch.jit.export + def get_batch_to_epoch_ratio(self): + epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min + + if epoch_activation_range == torch.tensor(float(0)): + raise ValueError("Range for Epoch is 0") + elif epoch_activation_range == torch.tensor(float("inf")): + raise ValueError( + "No data has been run through observer or infinity value present" + ) + else: + return self.average_batch_activation_range / epoch_activation_range + + @torch.jit.export + def reset_batch_and_epoch_values(self): + # set all the values back to their original defaults for a new epoch + # keep device + device = self.max_val.device + self.num_batches_tracked = 0 + self.average_batch_activation_range = torch.tensor(float(0), device=device) + self.epoch_activation_min = torch.tensor(float("inf"), device=device) + self.epoch_activation_max = torch.tensor(float("-inf"), device=device) + self.min_val = torch.tensor([], device=device) + self.max_val = torch.tensor([], device=device) + self.average_percentile_ratio = torch.tensor([], device=device) + self.percentile_batches_tracked = torch.tensor([], device=device) + self.constant_channels = torch.tensor([], device=device) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for ModelReportObserver" + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2e58772660c5a9067f727bf066b5519f65f37637 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -0,0 +1,712 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict, OrderedDict as OrdDict +from typing import Any + +import torch + + +# try to import tablate +got_tabulate = True +try: + from tabulate import tabulate +except ImportError: + got_tabulate = False + + +# var to see if we could import matplotlib +got_matplotlib = True +try: + import matplotlib.pyplot as plt +except ImportError: + got_matplotlib = False + + +class ModelReportVisualizer: + r""" + The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics + that were generated by the ModelReport API. However, at a higher level, the class aims to provide + some level of visualization of statistics to PyTorch in order to make it easier to parse data and + diagnose any potential issues with data or a specific model. With respect to the visualizations, + the ModelReportVisualizer class currently supports several methods of visualizing data. + + Supported Visualization Methods Include: + - Table format + - Plot format (line graph) + - Histogram format + + For all of the existing visualization methods, there is the option to filter data based on: + - A module fqn prefix + - Feature [required for the plot and histogram] + + * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below + Ensure sure that features that are the same across different report contain the same name + Ensure that objects representing the same features are the same type / dimension (where applicable) + + Note: + Currently, the ModelReportVisualizer class supports visualization of data generated by the + ModelReport class. However, this structure is extensible and should allow the visualization of + other information as long as the information is structured in the following general format: + + Report Structure + -- module_fqn [module with attached detectors] + | + -- feature keys [not every detector extracts same information] + [same collected info has same keys, unless can be specific to detector] + + + The goal behind the class is that the generated visualizations can be used in conjunction with the generated + report for people to get a better understanding of issues and what the fix might be. It is also just to provide + a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as + that grows in size. + + General Use Flow Expected + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers + 4.) Calibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance + 7.) Use instance to view different views of data as desired, applying filters as needed + 8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram + + """ + + # keys for table dict + TABLE_TENSOR_KEY = "tensor_level_info" + TABLE_CHANNEL_KEY = "channel_level_info" + + # Constants for header vals + NUM_NON_FEATURE_TENSOR_HEADERS = 2 + NUM_NON_FEATURE_CHANNEL_HEADERS = 3 + + # Constants for row index in header + CHANNEL_NUM_INDEX = 2 + + def __init__(self, generated_reports: OrderedDict[str, Any]): + r""" + Initializes the ModelReportVisualizer instance with the necessary reports. + + Args: + generated_reports (Dict[str, Any]): The reports generated by the ModelReport class + can also be a dictionary generated in another manner, as long as format is same + """ + self.generated_reports = generated_reports + + def get_all_unique_module_fqns(self) -> set[str]: + r""" + The purpose of this method is to provide a user the set of all module_fqns so that if + they wish to use some of the filtering capabilities of the ModelReportVisualizer class, + they don't need to manually parse the generated_reports dictionary to get this information. + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + # returns the keys of the ordered dict + return set(self.generated_reports.keys()) + + def get_all_unique_feature_names( + self, plottable_features_only: bool = True + ) -> set[str]: + r""" + The purpose of this method is to provide a user the set of all feature names so that if + they wish to use the filtering capabilities of the generate_table_view(), or use either of + the generate_plot_view() or generate_histogram_view(), they don't need to manually parse + the generated_reports dictionary to get this information. + + Args: + plottable_features_only (bool): True if the user is only looking for plottable features, + False otherwise + plottable features are those that are tensor values + Default: True (only return those feature names that are plottable) + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + unique_feature_names = set() + for module_fqn in self.generated_reports: + # get dict of the features + feature_dict: dict[str, Any] = self.generated_reports[module_fqn] + + # loop through features + for feature_name in feature_dict: + # if we need plottable, ensure type of val is tensor + if ( + not plottable_features_only + or type(feature_dict[feature_name]) is torch.Tensor + ): + unique_feature_names.add(feature_name) + + # return our compiled set of unique feature names + return unique_feature_names + + def _get_filtered_data( + self, feature_filter: str, module_fqn_filter: str + ) -> OrderedDict[str, Any]: + r""" + Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed. + + Args: + feature_filter (str): The feature filter, if we want to filter the set of data to only include + a certain set of features that include feature_filter + If feature = "", then we do not filter based on any features + module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with + this prefix will be included + If module_fqn_filter = "" we do not filter based on module fqn, and include all modules + + First, the data is filtered based on module_fqn, and then filtered based on feature + Returns an OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + """ + # create return dict + filtered_dict: OrderedDict[str, Any] = OrdDict() + + for module_fqn in self.generated_reports: + # first filter based on module + if module_fqn_filter == "" or module_fqn_filter in module_fqn: + # create entry for module and loop through features + filtered_dict[module_fqn] = {} + module_reports = self.generated_reports[module_fqn] + for feature_name in module_reports: + # check if filtering on features and do so if desired + if feature_filter == "" or feature_filter in feature_name: + filtered_dict[module_fqn][feature_name] = module_reports[ + feature_name + ] + + # we have populated the filtered dict, and must return it + + return filtered_dict + + def _generate_tensor_table( + self, + filtered_data: OrderedDict[str, dict[str, Any]], + tensor_features: list[str], + ) -> tuple[list, list]: + r""" + Takes in the filtered data and features list and generates the tensor headers and table + + Currently meant to generate the headers and table for both the tensor information. + + Args: + filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + tensor_features (List[str]): A list of the tensor level features + + Returns a tuple with: + A list of the headers of the tensor table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the tensor information table + tensor_table: list[list[Any]] = [] + tensor_headers: list[str] = [] + + # append the table row to the table only if we have features + if len(tensor_features) > 0: + # now we add all the data + for index, module_fqn in enumerate(filtered_data): + # we make a new row for the tensor table + tensor_table_row = [index, module_fqn] + for feature in tensor_features: + # we iterate in same order of added features + + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if isinstance(feature_val, torch.Tensor): + feature_val = feature_val.item() + + # we add to our list of values + # pyrefly: ignore [bad-argument-type] + tensor_table_row.append(feature_val) + + tensor_table.append(tensor_table_row) + + # add row of headers of we actually have something, otherwise just empty + if len(tensor_table) != 0: + tensor_headers = ["idx", "layer_fqn"] + tensor_features + + return (tensor_headers, tensor_table) + + def _generate_channels_table( + self, + filtered_data: OrderedDict[str, Any], + channel_features: list[str], + num_channels: int, + ) -> tuple[list, list]: + r""" + Takes in the filtered data and features list and generates the channels headers and table + + Currently meant to generate the headers and table for both the channels information. + + Args: + filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + channel_features (List[str]): A list of the channel level features + num_channels (int): Number of channels in the channel data + + Returns a tuple with: + A list of the headers of the channel table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the table for the channel information table + channel_table: list[list[Any]] = [] + channel_headers: list[str] = [] + + # counter to keep track of number of entries in + channel_table_entry_counter: int = 0 + + if len(channel_features) > 0: + # now we add all channel data + for module_fqn in filtered_data: + # we iterate over all channels + for channel in range(num_channels): + # we make a new row for the channel + new_channel_row = [channel_table_entry_counter, module_fqn, channel] + for feature in channel_features: + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature][channel] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if type(feature_val) is torch.Tensor: + feature_val = feature_val.item() + + # add value to channel specific row + # pyrefly: ignore [bad-argument-type] + new_channel_row.append(feature_val) + + # add to table and increment row index counter + channel_table.append(new_channel_row) + channel_table_entry_counter += 1 + + # add row of headers of we actually have something, otherwise just empty + if len(channel_table) != 0: + channel_headers = ["idx", "layer_fqn", "channel"] + channel_features + + return (channel_headers, channel_table) + + def generate_filtered_tables( + self, feature_filter: str = "", module_fqn_filter: str = "" + ) -> dict[str, tuple[list, list]]: + r""" + Takes in optional filter values and generates two tables with desired information. + + The generated tables are presented in both a list-of-lists format + + The reason for the two tables are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Returns a dictionary with two keys: + (Dict[str, Tuple[List, List]]) A dict containing two keys: + "tensor_level_info", "channel_level_info" + Each key maps to a tuple with: + A list of the headers of each table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_filtered_tables( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) # generates table with per_channel_min info for all modules in block 1 of the model + """ + # first get the filtered data + filtered_data: OrderedDict[str, Any] = self._get_filtered_data( + feature_filter, module_fqn_filter + ) + + # now we split into tensor and per-channel data + tensor_features: set[str] = set() + channel_features: set[str] = set() + + # keep track of the number of channels we have + num_channels: int = 0 + + for module_fqn in filtered_data: + for feature_name in filtered_data[module_fqn]: + # get the data for that specific feature + feature_data = filtered_data[module_fqn][feature_name] + + # check if not zero dim tensor + is_tensor: bool = isinstance(feature_data, torch.Tensor) + is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0 + + if is_not_zero_dim or isinstance(feature_data, list): + # works means per channel + channel_features.add(feature_name) + num_channels = len(feature_data) + else: + # means is per-tensor + tensor_features.add(feature_name) + + # we make them lists for iteration purposes + tensor_features_list: list[str] = sorted(tensor_features) + channel_features_list: list[str] = sorted(channel_features) + + # get the tensor info + tensor_headers, tensor_table = self._generate_tensor_table( + filtered_data, tensor_features_list + ) + + # get the channel info + channel_headers, channel_table = self._generate_channels_table( + filtered_data, channel_features_list, num_channels + ) + + # let's now create the dictionary to return + table_dict = { + self.TABLE_TENSOR_KEY: (tensor_headers, tensor_table), + self.TABLE_CHANNEL_KEY: (channel_headers, channel_table), + } + + # return the two tables + return table_dict + + def generate_table_visualization( + self, feature_filter: str = "", module_fqn_filter: str = "" + ): + r""" + Takes in optional filter values and prints out formatted tables of the information. + + The reason for the two tables printed out instead of one large one are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_table_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + >>> # prints out neatly formatted table with per_channel_min info + >>> # for all modules in block 1 of the model + """ + # see if we got tabulate + if not got_tabulate: + print("Make sure to install tabulate and try again.") + return None + + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # get the table string and print it out + # now we have populated the tables for each one + # let's create the strings to be returned + table_str = "" + # the tables will have some headers columns that are non-feature + # ex. table index, module name, channel index, etc. + # we want to look at header columns for features, that come after those headers + if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS: + # if we have at least one tensor level feature to be added we add tensor table + table_str += "Tensor Level Information \n" + table_str += tabulate(tensor_table, headers=tensor_headers) + if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS: + # if we have at least one channel level feature to be added we add tensor table + table_str += "\n\n Channel Level Information \n" + table_str += tabulate(channel_table, headers=channel_headers) + + # if no features at all, let user know + if table_str == "": + table_str = "No data points to generate table with." + + print(table_str) + + def _get_plottable_data( + self, feature_filter: str, module_fqn_filter: str + ) -> tuple[list, list[list], bool]: + r""" + Takes in the feature filters and module filters and outputs the x and y data for plotting + + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str): Only includes modules that contains this string + + Returns a tuple of three elements + The first is a list containing relevant x-axis data + The second is a list containing the corresponding y-axis data + If the data is per channel + """ + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # make sure it is only 1 feature that is being plotted + # get the number of features in each of these + tensor_info_features_count = ( + len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + ) + channel_info_features_count = ( + len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + ) + + # see if valid tensor or channel plot + is_valid_per_tensor_plot: bool = tensor_info_features_count == 1 + is_valid_per_channel_plot: bool = channel_info_features_count == 1 + + # offset should either be one of tensor or channel table or neither + feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + table = tensor_table + + # if a per_channel plot, we have different offset and table + if is_valid_per_channel_plot: + feature_column_offset = ( + ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + ) + table = channel_table + + x_data: list = [] + y_data: list[list] = [] + # the feature will either be a tensor feature or channel feature + if is_valid_per_tensor_plot: + for table_row_num, row in enumerate(table): + # get x_value to append + x_val_to_append = table_row_num + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if type(row_value) is not str: + x_data.append(x_val_to_append) + y_data.append(row_value) + elif is_valid_per_channel_plot: + # gather the x_data and multiple y_data + # calculate the number of channels + num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1 + + # separate data list per channel + y_data.extend([] for _ in range(num_channels)) + + for table_row_num, row in enumerate(table): + # get x_value to append + x_val_to_append = table_row_num + current_channel = row[ + self.CHANNEL_NUM_INDEX + ] # initially chose current channel + new_module_index: int = table_row_num // num_channels + x_val_to_append = new_module_index + + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if type(row_value) is not str: + # only append if new index we are appending + if len(x_data) == 0 or x_data[-1] != x_val_to_append: + x_data.append(x_val_to_append) + + # append value for that channel + y_data[current_channel].append(row_value) + else: + # more than one feature was chosen + error_str = "Make sure to pick only a single feature with your filter to plot a graph." + error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names." + error_str += " Pick one of those features to plot." + raise ValueError(error_str) + + # return x, y values, and if data is per-channel + return (x_data, y_data, is_valid_per_channel_plot) + + def generate_plot_visualization( + self, feature_filter: str, module_fqn_filter: str = "" + ): + r""" + Takes in a feature and optional module_filter and plots of the desired data. + + For per channel features, it averages the value across the channels and plots a point + per module. The reason for this is that for models with hundreds of channels, it can + be hard to differentiate one channel line from another, and so the point of generating + a single average point per module is to give a sense of general trends that encourage + further deep dives. + + Note: + Only features in the report that have tensor value data are plottable by this class + When the tensor information is plotted, it will plot: + idx as the x val, feature value as the y_val + When the channel information is plotted, it will plot: + the first idx of each module as the x val, feature value as the y_val [for each channel] + The reason for this is that we want to be able to compare values across the + channels for same layer, and it will be hard if values are staggered by idx + This means each module is represented by only 1 x value + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_plot_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + >>> # outputs line plot of per_channel_min information for all + >>> # modules in block1 of model each channel gets it's own line, + >>> # and it's plotted across the in-order modules on the x-axis + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + x_data, y_data, data_per_channel = self._get_plottable_data( + feature_filter, module_fqn_filter + ) + + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_ylabel(feature_filter) + ax.set_title(feature_filter + " Plot") + plt.xticks(x_data) # only show ticks for actual points + + if data_per_channel: + ax.set_xlabel("First idx of module") + # set the legend as well + # plot a single line that is average of the channel values + num_modules = len( + y_data[0] + ) # all y_data have same length, so get num modules + num_channels = len( + y_data + ) # we want num channels to be able to calculate average later + + avg_vals = [ + sum(y_data[:][index]) / num_channels for index in range(num_modules) + ] + + # plot the three things we measured + ax.plot( + x_data, avg_vals, label=f"Average Value Across {num_channels} Channels" + ) + ax.legend(loc="upper right") + else: + ax.set_xlabel("idx") + ax.plot(x_data, y_data) + + # actually show the plot + plt.show() + + def generate_histogram_visualization( + self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10 + ): + r""" + Takes in a feature and optional module_filter and plots the histogram of desired data. + + Note: + Only features in the report that have tensor value data can be viewed as a histogram + If you want to plot a histogram from all the channel values of a specific feature for + a specific model, make sure to specify both the model and the feature properly + in the filters and you should be able to see a distribution of the channel data + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + num_bins (int, optional): The number of bins to create the histogram with + Default = 10, the values will be split into 10 equal sized bins + + Example Use: + >>> # xdoctest: +SKIP + >>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + # outputs histogram of per_channel_min information for all modules in block1 of model + information is gathered across all channels for all modules in block 1 for the + per_channel_min and is displayed in a histogram of equally sized bins + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + _x_data, y_data, data_per_channel = self._get_plottable_data( + feature_filter, module_fqn_filter + ) + + # for histogram, we just care about plotting the y data + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_xlabel(feature_filter) + ax.set_ylabel("Frequency") + ax.set_title(feature_filter + " Histogram") + + if data_per_channel: + # set the legend as well + # combine all the data + all_data = [] + for channel_info in y_data: + all_data.extend(channel_info) + + _val, bins, _ = plt.hist( + all_data, + bins=num_bins, + stacked=True, + rwidth=0.8, + ) + plt.xticks(bins) + else: + _val, bins, _ = plt.hist( + y_data, + bins=num_bins, + stacked=False, + rwidth=0.8, + ) + plt.xticks(bins) + + plt.show() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/observer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..abb81c2a54d0091e16ff7cbbf6ef6bb2112485de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/observer.py @@ -0,0 +1,2155 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# temporarily skip RUF for this file for now, we can re-enable +# after move the affine quantization related things to torchao +# noqa: RUF +""" +This module implements observers which are used to collect statistics about +the values observed during calibration (PTQ) or training (QAT). +""" + +import operator +import re +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from functools import partial +from typing import Any + +import torch +import torch.nn as nn +from torch.ao.quantization.utils import ( + calculate_qmin_qmax, + check_min_max_valid, + is_per_channel, + is_per_tensor, + validate_qmin_qmax, +) +from torch.fx import Node + + +__all__ = [ + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_quant_observer", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_observer", + "default_weight_observer", + "get_observer_state_dict", + "load_observer_state_dict", + "per_channel_weight_observer_range_neg_127_to_127", + "weight_observer_range_neg_127_to_127", + "FixedQParamsObserver", + "HistogramObserver", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +class _PartialWrapper: + def __init__(self, p): + self.p = p + self.callable_args = {} + + def __call__(self, *args, **keywords): + # call each arg in callable_args and add them partial, then run with keywords + # skip if arg_name in keywords so its possible to overwrite + for arg_name in self.callable_args: + if arg_name not in keywords: + keywords = {**keywords, arg_name: self.callable_args[arg_name]()} + return self.p(*args, **keywords) + + def __repr__(self): + return self.p.__repr__() + self.callable_args.__repr__() + + def with_args(self, **kwargs): + return _with_args(self, **kwargs) + + def with_callable_args(self, **kwargs): + result = _PartialWrapper(p=self.p) + result.callable_args = {**self.callable_args, **kwargs} + return result + + +def _with_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. Can be used in conjunction with + _callable_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + r = _PartialWrapper(partial(cls_or_self, **kwargs)) + return r + + +def _with_callable_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories args that need to be + called at construction time. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances and those arguments should only + be calculated at construction time. Can be used in conjunction with _with_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_callable_args = classmethod(_with_callable_args) + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") + >>> foo_instance1 = foo_builder() + >>> # wait 50 + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) + False + """ + r = _PartialWrapper(partial(cls_or_self)) + return r.with_callable_args(**kwargs) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + + +class ObserverBase(ABC, nn.Module): + r"""Base observer Module. + Any observer implementation should derive from this class. + + Concrete observers should follow the same API. In forward, they will update + the statistics of the observed Tensor. And they should provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization + or static quantization + """ + + def __init__(self, dtype, is_dynamic: bool = False): + super().__init__() + self.dtype = dtype + self.is_dynamic = is_dynamic + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + with_args = classmethod(_with_args) + with_callable_args = classmethod(_with_callable_args) + + +class UniformQuantizationObserverBase(ObserverBase): + r"""Common base for all observers using uniform quantization to calculate + scale and zero_point. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + .. warning:: + + :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + or `torch.int8` or `torch.uint8` + + .. warning:: + + :attr:`qscheme` can only take one of the following options: + + - ``torch.per_tensor_affine`` + - ``torch.per_tensor_symmetric`` + - ``torch.per_channel_affine`` + - ``torch.per_channel_symmetric`` + """ + + # Note: the version is shared by all observer types + # + # Version 1/None + # self + # + # Version 2 (base class only, does not include child class buffers) + # self + # |--- eps : Tensor + # + # Version 3 + # for HistogramObserver only, changed the shape of uninitialized + # min_val and max_val buffers from torch.Size([0]) to torch.Size([]) + # for PerChannelObservers, changed the name of the buffers from min_vals + # to min_val and from max_vals to max_val. + _version = 3 + + eps: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.qscheme = qscheme + if reduce_range: + warnings.warn( + "Please use quant_min and quant_max to specify the range for observers. \ + reduce_range will be deprecated in a future release of PyTorch.", + stacklevel=2, + ) + self.reduce_range = reduce_range + self.register_buffer("eps", torch.tensor([eps], **factory_kwargs)) + if self.qscheme not in ( + torch.per_tensor_affine, + torch.per_tensor_symmetric, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + ): + raise AssertionError( + "Default Observer only works for per_tensor_affine, per_tensor_symmetric, " + "per_channel_affine, per_channel_symmetric and per_channel_float_qparams quantization scheme" + ) + + _ALLOWED_DTYPES = ( + torch.qint8, + torch.quint8, + torch.quint4x2, + torch.qint32, + torch.int8, + torch.uint8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.uint16, + ) + + if self.dtype not in _ALLOWED_DTYPES: + raise AssertionError( + f"Default Observer only works for {_ALLOWED_DTYPES} data type" + ) + self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) + if self.has_customized_qrange: + # pyrefly: ignore [bad-argument-type] + validate_qmin_qmax(quant_min, quant_max) + self.quant_min, self.quant_max = calculate_qmin_qmax( + # pyrefly: ignore [bad-argument-type] + quant_min, + # pyrefly: ignore [bad-argument-type] + quant_max, + self.has_customized_qrange, + self.dtype, + self.reduce_range, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version == 1: + # eps was moved to a buffer in version 2 + eps = torch.tensor([torch.finfo(torch.float32).eps]) + state_dict[prefix + "eps"] = eps + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. + + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + The related literatures for scale and zero point via backpropagation are as follows: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. + if not quant_min <= 0 <= quant_max: + raise AssertionError("Used-specified quantization range must include 0.") + if quant_min >= quant_max: + raise AssertionError( + "qmin must be strictly less than qmax for user-specified quantization range." + ) + + @torch.jit.export + def _calculate_qparams( + self, min_val: torch.Tensor, max_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme + # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer + # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code + # seems unlikely to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. + # TODO(jakeszwe, jerryzh168) + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor( + [0], device=min_val.device.type + ) + + quant_min, quant_max = self.quant_min, self.quant_max + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + if ( + self.qscheme == torch.per_tensor_symmetric + or self.qscheme == torch.per_channel_symmetric + ): + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, self.eps) + if self.dtype in [torch.quint8, torch.uint8]: + if self.has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. + zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif self.dtype == torch.uint16: + zero_point = zero_point.new_full(zero_point.size(), 2**15) + elif self.qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, self.eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. + if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if self.qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale, zero_point + + @torch.jit.export + def reset_min_max_vals(self): + raise NotImplementedError("Cannot reset min/max values in the given observer.") + + +# Originally, this class was called `_ObserverBase`. Keeping the old name around +# for backwards compatibility. +# TODO(after v1.13): delete this +_ObserverBase = UniformQuantizationObserverBase + + +class MinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running min and max values. + + This observer uses the tensor min/max statistics to compute the quantization + parameters. The module records the running minimum and maximum of incoming + tensors, and uses this statistic to compute the quantization parameters. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, + scale :math:`s` and zero point :math:`z` are computed as: + + The running minimum/maximum :math:`x_\text{min/max}` is computed as: + + .. math:: + + \begin{array}{ll} + x_\text{min} &= \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} + \end{cases}\\ + x_\text{max} &= \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`X` is the observed tensor. + + The scale :math:`s` and zero point :math:`z` are then computed as: + + .. math:: + + \begin{aligned} + \text{if Symmetric:}&\\ + &s = 2 \max(|x_\text{min}|, x_\text{max}) / + \left( Q_\text{max} - Q_\text{min} \right) \\ + &z = \begin{cases} + 0 & \text{if dtype is qint8} \\ + 128 & \text{otherwise} + \end{cases}\\ + \text{Otherwise:}&\\ + &s = \left( x_\text{max} - x_\text{min} \right ) / + \left( Q_\text{max} - Q_\text{min} \right ) \\ + &z = Q_\text{min} - \text{round}(x_\text{min} / s) + \end{aligned} + + where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and + maximum of the quantized data type. + + .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but + # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it + # supports dynamic quantization, we may need to better error checking here + + # For x86 quantized kernels, we need to ensure that the vpmaddubsw + # instruction does not overflow. We allow for a reduce_range argument to + # observers that reduces the quantized range to (0,127) or (-64, 63). + # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp + # This is not an optimal choice for non x86 backends as it loses a bit + # of precision for activations. + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + if ( + self.qscheme == torch.per_tensor_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric \ + quantization for quint8" + ) + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + r"""Calculates the quantization parameters.""" + return self._calculate_qparams(self.min_val, self.max_val) + + @torch.jit.export + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float("inf"))) + self.max_val.copy_(torch.tensor(float("-inf"))) + + +class MovingAverageMinMaxObserver(MinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + moving average of the min and max values. + + This observer computes the quantization parameters based on the moving + averages of minimums and maximums of the incoming tensors. The module + records the average minimum and maximum of incoming tensors, and uses this + statistic to compute the quantization parameters. + + Args: + averaging_constant: Averaging constant for min/max. + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The moving average min/max is computed as follows + + .. math:: + + \begin{array}{ll} + x_\text{min} = \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + (1 - c) x_\text{min} + c \min(X) & \text{otherwise} + \end{cases}\\ + x_\text{max} = \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + (1 - c) x_\text{max} + c \max(X) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is + is the incoming tensor, and :math:`c` is the ``averaging_constant``. + + The scale and zero point are then computed as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`. + + .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + f"MovingAverageMinMaxObserver's qscheme only support \ + torch.per_tensor_symmetric and torch.per_tensor_affine. \ + but got: {qscheme}" + ) + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + if min_val == float("inf") and max_val == float("-inf"): + min_val, max_val = torch.aminmax(x) + else: + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class PerChannelMinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + ch_axis: Channel axis + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference + that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "PerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "PerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.ch_axis = ch_axis + self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) + self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) + if ( + self.qscheme == torch.per_channel_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric quantization for quint8" + ) + + def forward(self, x_orig): + return self._forward(x_orig) + + def _forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self._calculate_qparams(self.min_val, self.max_val) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + def _load_from_state_dict( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + version = local_metadata.get("version") + if version is not None and version < 3: + local_state = ["min_vals", "max_vals"] + expected_min_name = "min_vals" + expected_max_name = "max_vals" + else: + local_state = ["min_val", "max_val"] + expected_min_name = "min_val" + expected_max_name = "max_val" + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading min_val or max_val + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == expected_min_name: + self.min_val.resize_(val.shape) + elif name == expected_max_name: + self.max_val.resize_(val.shape) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}", + stacklevel=2, + ) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == expected_min_name: + self.min_val.copy_(val) + elif name == expected_max_name: + self.max_val.copy_(val) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}", + stacklevel=2, + ) + elif strict: + missing_keys.append(key) + + if not torch.jit.is_scripting(): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _load_from_state_dict_script( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + self._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + # This used to be torch.ones but that does not work because + # JIT compiler can optimize it via common subexpression elimination + # in which case both min_val and max_val point to the same tensor. + self.min_val = torch.rand( + 0, + ) + self.max_val = torch.rand( + 0, + ) + + +class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + averaging_constant: Averaging constant for min/max. + ch_axis: Channel axis + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the + difference that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + ch_axis=ch_axis, + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + self.averaging_constant = averaging_constant + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class HistogramObserver(UniformQuantizationObserverBase): + r""" + The module records the running histogram of tensor values along with + min/max values. ``calculate_qparams`` will calculate scale and zero_point. + + Args: + bins: Number of bins to use for the histogram + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The scale and zero point are computed as follows: + + 1. Create the histogram of the incoming inputs. + The histogram is computed continuously, and the ranges per bin change + with every new tensor observed. + 2. Search the distribution in the histogram for optimal min/max values. + The search for the min/max values ensures the minimization of the + quantization error with respect to the floating point model. + 3. Compute the scale and zero point the same way as in the + :class:`~torch.ao.quantization.MinMaxObserver` + """ + + histogram: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + bins: int = 2048, + dtype: torch.dtype = torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "HistogramObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + if is_dynamic: + raise NotImplementedError( + "HistogramObserver doesn't support dynamic quantization" + ) + # bins: The number of bins used for histogram calculation. + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.bins = bins + self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits + self.upsample_rate = ( + 16 # used to reduce quantization errors when upscaling histogram + ) + + def _get_norm( + self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor + ) -> torch.Tensor: + r""" + Compute the norm of the values uniformaly distributed between + delta_begin and delta_end. + Currently only L2 norm is supported. + + norm = density * (integral_{begin, end} x^2) + = density * (end^3 - begin^3) / 3 + """ + norm = ( + delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin + ) / 3 + return density * norm + + def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): + r""" + Compute the quantization error if we use start_bin to end_bin as the + min and max to do the quantization. + """ + bin_width = (self.max_val.item() - self.min_val.item()) / self.bins + + dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins + if dst_bin_width == 0.0: + return 0.0 + + src_bin = torch.arange(self.bins, device=self.histogram.device) + # distances from the beginning of first dst_bin to the beginning and + # end of src_bin + src_bin_begin = (src_bin - next_start_bin) * bin_width + src_bin_end = src_bin_begin + bin_width + + # which dst_bins the beginning and end of src_bin belong to? + dst_bin_of_begin = torch.clamp( + torch.div(src_bin_begin, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width + + dst_bin_of_end = torch.clamp( + torch.div(src_bin_end, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + density = self.histogram / bin_width + + norm = torch.zeros(self.bins, device=self.histogram.device) + + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = dst_bin_width / 2 + norm += self._get_norm( + delta_begin, + torch.ones(self.bins, device=self.histogram.device) * delta_end, + density, + ) + + norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( + torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density + ) + + dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 + + delta_begin = -dst_bin_width / 2 + delta_end = src_bin_end - dst_bin_of_end_center + norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) + + return norm.sum().item() + + def _non_linear_param_search(self) -> tuple[torch.Tensor, torch.Tensor]: + r"""Non-linear parameter search. + + An approximation for L2 error minimization for selecting min/max. + By selecting new min/max, we filter out outliers in input distribution. + This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in + caffe2/quantization/server/norm_minimization.cc + """ + if self.histogram.size()[0] != self.bins: + raise AssertionError("bins mismatch") + bin_width = (self.max_val - self.min_val) / self.bins + + # cumulative sum + total = torch.sum(self.histogram).item() + cSum = torch.cumsum(self.histogram, dim=0) + + stepsize = 1e-5 # granularity + alpha = 0.0 # lower bound + beta = 1.0 # upper bound + start_bin = 0 + end_bin = self.bins - 1 + norm_min = float("inf") + + while alpha < beta: + # Find the next step + next_alpha = alpha + stepsize + next_beta = beta - stepsize + + # find the left and right bins between the quantile bounds + l = start_bin + r = end_bin + while l < end_bin and cSum[l] < next_alpha * total: + l = l + 1 + while r > start_bin and cSum[r] > next_beta * total: + r = r - 1 + + # decide the next move + next_start_bin = start_bin + next_end_bin = end_bin + if (l - start_bin) > (end_bin - r): + # move the start bin + next_start_bin = l + alpha = next_alpha + else: + # move the end bin + next_end_bin = r + beta = next_beta + + if next_start_bin == start_bin and next_end_bin == end_bin: + continue + + # calculate the quantization error using next_start_bin and next_end_bin + norm = self._compute_quantization_error(next_start_bin, next_end_bin) + + if norm > norm_min: + break + norm_min = norm + start_bin = next_start_bin + end_bin = next_end_bin + + new_min = self.min_val + bin_width * start_bin + new_max = self.min_val + bin_width * (end_bin + 1) + return new_min, new_max + + def _upscale_histogram( + self, + histogram: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ): + # this turns the histogram into a more fine-coarsed histogram to reduce + # bin quantization errors + histogram = histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate + bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate) + mid_points_histogram = ( + torch.linspace( + orig_min, + orig_max, + self.bins * self.upsample_rate + 1, + device=orig_min.device, + )[:-1].to(histogram.device) + + 0.5 * bin_size + ) + boundaries_new_histogram = torch.linspace( + update_min, update_max, self.bins + 1, device=update_min.device + ).to(histogram.device) + # this maps the mid-points of the histogram to the new histogram's space + bucket_assignments = ( + torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True) + - 1 + ) + # this then maps the histogram mid-points in the new space, weighted by the original histogram's values + # this is just the old histogram in the new histogram's space + + # In case due to numerical issues the values land higher/lower than the maximum/minimum + bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1 + bucket_assignments[bucket_assignments < 0] = 0 + + update_histogram = torch.bincount( + bucket_assignments, weights=histogram, minlength=self.bins + ) + return update_histogram + + def _combine_histograms( + self, + orig_hist: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_hist: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ) -> torch.Tensor: + # If the new min and max are the same as the current min and max, + # we can just add the new histogram to the original histogram + if update_min == orig_min and update_max == orig_max: + return orig_hist + update_hist + + # If the orig hist only has one value (i.e., the min and max are the same) + # we can just add it into new histogram + if orig_min == orig_max: + bin_value = torch.sum(orig_hist) + transformed_orig_hist = ( + torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max) # type: ignore[arg-type] + * bin_value + ) + return transformed_orig_hist + update_hist + + # We assume the update_hist is already in the target range, we will map the orig_max to it + if update_min > orig_min: + raise AssertionError("update_min must be <= orig_min") + if update_max < orig_max: + raise AssertionError("update_max must be >= orig_max") + + # Now we need to turn the old_histogram, into the range of the new histogram + transformed_orig_hist = self._upscale_histogram( + orig_hist, + orig_min, + orig_max, + update_min, + update_max, + ) + + return update_hist + transformed_orig_hist + + def reset_histogram( + self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor + ) -> None: + self.min_val.resize_(min_val.shape) + self.min_val.copy_(min_val) + self.max_val.resize_(max_val.shape) + self.max_val.copy_(max_val) + if min_val.numel() != 1 or max_val.numel() != 1: + raise AssertionError("histogram min/max values must be scalar.") + new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type] + self.histogram.detach_().resize_(new_histogram.shape) + self.histogram.copy_(new_histogram) + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: # pyre-ignore[14] + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() + x_min, x_max = torch.aminmax(x) + # want to ignore torch.inf since we don't actually + # want to make our quantization range infinite + # and in practice those values will be clamped + if x_min == -torch.inf or x_max == torch.inf: + warnings.warn( + "torch.inf detected in input tensor, ignoring input", stacklevel=2 + ) + x = x[x.abs() != torch.inf] + if x.numel() == 0: + return x_orig + x_min, x_max = torch.aminmax(x) + + current_min = self.min_val + current_max = self.max_val + + is_uninitialized = self.min_val == float("inf") or self.max_val == float("-inf") + if is_uninitialized: + self.reset_histogram(x, x_min, x_max) + else: + update_min, update_max = x_min, x_max + new_min = torch.min(current_min, update_min) + new_max = torch.max(current_max, update_max) + + # TODO: For some reason, this is required for it to pass torchscript test + # new_min and new_max should already have requires_grad set to False + new_min, new_max = new_min.detach(), new_max.detach() + update_histogram = torch.histc( + x, + self.bins, + min=new_min, # type: ignore[arg-type] + max=new_max, # type: ignore[arg-type] + ).to(self.histogram.device) + if new_min == current_min and new_max == current_max: + combined_histogram = self.histogram + update_histogram + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + else: + combined_histogram = self._combine_histograms( + self.histogram, + current_min, + current_max, + update_histogram, + new_min, + new_max, + ) + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + self.min_val.detach_().resize_(new_min.shape) + self.min_val.copy_(new_min) + self.max_val.detach_().resize_(new_max.shape) + self.max_val.copy_(new_max) + + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + is_uninitialized = self.min_val == float("inf") and self.max_val == float( + "-inf" + ) + if is_uninitialized: + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point ", + stacklevel=2, + ) + return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor( + [0], device=self.min_val.device.type + ) + if self.bins != len(self.histogram): + raise AssertionError( + "The number of bins in histogram should be equal to the number of bins " + "supplied while making this observer" + ) + + new_min, new_max = self._non_linear_param_search() + + return self._calculate_qparams(new_min, new_max) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "min_val"] = self.min_val + destination[prefix + "max_val"] = self.max_val + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 3: + # if min_val and max_val are not initialized, update their shape + # to account for the differences between v2 and v3 + min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" + if min_val_name in state_dict: + if state_dict[min_val_name].shape == torch.Size([0]): + state_dict[min_val_name] = torch.tensor(float("inf")) + if max_val_name in state_dict: + if state_dict[max_val_name].shape == torch.Size([0]): + state_dict[max_val_name] = torch.tensor(float("-inf")) + + local_state = ["min_val", "max_val"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + setattr(self, name, val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + +class FixedQParamsObserver(ObserverBase): + r""" + Observer that simulates quantize and dequantize with fixed + quantization parameters in training time. Only per tensor + quantization is supported. + + Args: + `scale` (float): fixed scale for the observer + `zero_point` (int): fixed zero point for the observer + `dtype`, `qscheme`, `quant_min`, `quant_max` + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + scale, + zero_point, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=0, + quant_max=255, + is_dynamic=False, + **kwargs, + ): + if is_dynamic: + raise NotImplementedError( + "FixedQParamsObserver doesn't support dynamic quantization" + ) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.quant_min = quant_min + self.quant_max = quant_max + self.register_buffer("scale", torch.tensor([scale], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([zero_point], dtype=torch.int)) + self.dtype = dtype + self.qscheme = qscheme + + def forward(self, X): + return X + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.scale, self.zero_point + + +class PlaceholderObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Can be used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + quant_min: minimum value in quantized domain (TODO: align behavior with other observers) + quant_max: maximum value in quantized domain + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + compute_dtype (deprecated): if set, marks the future quantize function to use + dynamic quantization instead of static quantization. + This field is deprecated, use `is_dynamic=True` instead. + is_dynamic: if True, the `quantize` function in the reference model + representation taking stats from this observer instance will + use dynamic quantization. + """ + + def __init__( + self, + dtype=torch.float32, + custom_op_name="", + compute_dtype=None, + quant_min=None, + quant_max=None, + qscheme=None, + eps=None, + is_dynamic=False, + ) -> None: + super().__init__(dtype=dtype, is_dynamic=is_dynamic) + if qscheme is None: + qscheme = torch.per_tensor_affine + if eps is None: + eps = torch.finfo(torch.float32).eps + + # dtype of input of the target operator, e.g. for dynamic quantization + # ops, the dtype will be float32 + self.dtype = dtype + self.qscheme = qscheme + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.custom_op = custom_op_name + # used for configuration of computation type for dynamic quantization + if compute_dtype: + is_dynamic = True + warnings.warn( + "Please use `is_dynamic` instead of `compute_dtype`. \ + `compute_dtype` will be deprecated in a future release \ + of PyTorch.", + stacklevel=2, + ) + + def forward(self, x): + return x + + @torch.jit.export + def extra_repr(self): + return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) + + +class RecordingObserver(ObserverBase): + r""" + The module is mainly for debug and records the tensor values during runtime. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + """ + + __annotations__ = {"tensor_val": list[torch.Tensor | None]} + + def __init__(self, dtype=torch.quint8): + super().__init__(dtype=dtype, is_dynamic=False) + self.tensor_val = [] + + def forward(self, x): + self.tensor_val.append(x.clone()) + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for RecordingObserver" + ) + + @torch.jit.export + def get_tensor_value(self): + return self.tensor_val + + +class NoopObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Primarily used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: Quantized data type + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + """ + + def __init__(self, dtype=torch.float16, custom_op_name="") -> None: + super().__init__(dtype=dtype, is_dynamic=False) + self.dtype = dtype + self.custom_op = custom_op_name + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for NoopObserver" + ) + + +class ReuseInputObserver(ObserverBase): + r"""This observer is used when we want to reuse the observer from the operator + that produces the input Tensor, typically used for operators like reshape, e.g. + ``` + x0 = ... + x1 = x0.reshape() + ``` + if we configure x0 to be observed by some observer, let's say MinMaxObserver, + and reshape is configured with ReuseInputObserver, we'll reuse the observer instance + for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1. + + Note: this is only enabled in FX Graph Mode Quantization + """ + + def __init__(self) -> None: + super().__init__(torch.quint8, is_dynamic=False) + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for ReuseInputObserver" + ) + + +""" +# Experimental Affine Quantization Feature START +We plan to merge the following with torchao repo after we move pt2e flow to torchao +copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +""" +from dataclasses import dataclass +from enum import auto, Enum + + +class MappingType(Enum): + """How floating point number is mapped to integer number + + symmetric mapping means floating point range is symmetrically mapped to integer range + let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) + we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) + e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) + + SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin + and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating + smin and smax individually, there can be less round error on negative values, and no out-of-range + of all floating point values. + + asymmetric mapping means we just directly map the floating point range to integer range, + for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter + based on this mapping + e.g. scale = (10.2 - (-3.5)) / (7 - (-8)) + """ + + SYMMETRIC = auto() + SYMMETRIC_NO_CLIPPING_ERR = auto() + ASYMMETRIC = auto() + + +class ZeroPointDomain(Enum): + """Enum that indicate whether zero_point is in integer domain or floating point domain + + integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) + float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + none domain: quantized_val = (float_val / scale) + """ + + INT = auto() + FLOAT = auto() + NONE = auto() + + +class TorchAODType(Enum): + """ + Placeholder for dtypes that do not exist in PyTorch core yet. + """ + + # torch.int1 to torch.int7 will be added to PyTorch 2.6 + # These will remain here for BC with older PyTorch versions + INT1 = auto() + INT2 = auto() + INT3 = auto() + INT4 = auto() + INT5 = auto() + INT6 = auto() + INT7 = auto() + + +@dataclass(frozen=True) +class Granularity: + """ + Base class for representing the granularity of quantization. + + This class serves as a parent for specific granularity types used in + quantization operations, such as per-tensor or per-axis quantization. + """ + + +@dataclass(frozen=True) +class PerBlock(Granularity): + """ + Represents per-block granularity in quantization. See + :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for + `block_size` + + Attributes: + block_size (Tuple[int, ...]): The size of each quantization group + """ + + block_size: tuple[int, ...] + + +@dataclass(frozen=True) +class PerTensor(Granularity): + """ + Represents per-tensor granularity in quantization. + + This granularity type calculates the quantization parameters + based off the entire tensor. + + """ + + +@dataclass(frozen=True) +class PerAxis(Granularity): + """ + Represents per-axis granularity in quantization. + + This granularity type calculates different quantization parameters + along a specified axis of the tensor. + + For example if the input tensor is shape [8, 16] and axis=0, then + the quantization parameters are calculated for each row of the tensor. + Giving a total of 8 quantization parameters. + + Attributes: + axis (int): The axis along which reduction is performed. + """ + + axis: int + + +@dataclass(frozen=True) +class PerGroup(Granularity): + """ + Represents per-channel group granularity in quantization. + + This granularity type calculates different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [64, 4] + quantization parameters are calculated for each group of 4 elements, + giving a total of 64 quantization parameters. + + Attributes: + group_size (int): The size of each quantization group + + """ + + group_size: int + + +class PerRow(Granularity): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). + """ + + +class PerToken(Granularity): + """ + Represents per-token granularity in quantization. + + This granularity type calculates a different set of quantization parameters + for each token, which is represented as the last dimension of the tensor. + + For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens + with 4 elements each, and we will calculate 6 sets of quantization parameters, + one for each token. + + If the input tensor has only two dimensions, e.g. [8, 16], then this is + equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. + """ + + +def get_block_size( + input_shape: tuple[int, ...], granularity: Granularity +) -> tuple[int, ...]: + """Get the block size based on the input shape and granularity type. + + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + granularity: The granularity type of the quantization + """ + if not isinstance(granularity, Granularity): + raise AssertionError( + "Please provide an instance of Granularity, not subclass of it" + ) + if isinstance(granularity, PerTensor): + return input_shape + elif isinstance(granularity, PerAxis): + block_size = list(input_shape) + block_size[granularity.axis] = 1 + return tuple(block_size) + elif isinstance(granularity, PerRow): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + if len(input_shape) != 2: + raise AssertionError( + f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + ) + return (1, granularity.group_size) + elif isinstance(granularity, PerToken): + block_size = [1] * len(input_shape) + block_size[-1] = input_shape[-1] + return tuple(block_size) + raise ValueError(f"Unsupported Granularity: {granularity}") + + +class AffineQuantizedObserverBase(ABC, torch.nn.Module): + """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) + + Args: + `granularity` and `block_size`: The granularity of the quantization, + must specify at least one, if both are specified `block_size` takes precedence + Current supported granularity type are `PerTensor` and `PerAxis` + other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` + """ + + with_args = classmethod(_with_args) + + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: int | None = None, + quant_max: int | None = None, + eps: float | None = None, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + super().__init__() + if granularity is None: + raise AssertionError("granularity is None") + self.mapping_type = mapping_type + self.target_dtype = target_dtype + self.granularity = granularity + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + # populatd during forward + self.block_size = None + self.original_dtype = None + + @abstractmethod + def forward(self, input: torch.Tensor) -> torch.Tensor: + """forward function should take the input tensor + and updates internal stats and return the original input Tensor + """ + + @abstractmethod + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate quantization parameter based on the stats attached to the observer module + and returns a tuple of scale and zero_point Tensor + """ + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + """ + Converts the observer node in the graph into its quantized representation + + Args: + model: graph module to convert the observer node in + observer_node: the observer node to convert + """ + from torch.ao.quantization.fx.utils import create_getattr_from_value + + with model.graph.inserting_before(observer_node): + if self.block_size is None: + raise AssertionError("Expecting block_size to be populated") + if self.original_dtype is None: + raise AssertionError("Expecting original_dtype to be populated") + if hasattr(self, "is_dynamic") and self.is_dynamic: + choose_qparams_affine = model.graph.call_function( + torch.ops.pt2e_quant.choose_qparams_affine, + ( + observer_node.args[0], + self.mapping_type.name, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain.name, + ), + ) + scale_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 0) + ) + zero_point_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 1) + ) + else: + scale, zero_point = self.calculate_qparams() + scale_node = create_getattr_from_value( + model, + model.graph, + "_scale", + scale, + scale.device if isinstance(scale, torch.Tensor) else None, + ) + zero_point_node = create_getattr_from_value( + model, + model.graph, + "_zero_point", + zero_point, + zero_point.device if isinstance(zero_point, torch.Tensor) else None, + ) + + q_node = model.graph.call_function( + torch.ops.pt2e_quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.pt2e_quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + + +def _is_observer_script_module(mod, obs_type_name): + """Returns true if given mod is an instance of Observer script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return obs_type_name in name + return False + + +# Experimental Affine Quantization Feature END + + +def _is_activation_post_process(module): + return isinstance( + module, + ( + torch.ao.quantization.ObserverBase, + torch.ao.quantization.FakeQuantizeBase, + AffineQuantizedObserverBase, + ), + ) or _is_observer_script_module(module, "quantization.observer") + + +def _is_per_channel_script_obs_instance(module): + if isinstance(module, torch.jit.RecursiveScriptModule): + return _is_observer_script_module( + module, "quantization.observer.PerChannelMinMaxObserver" + ) or _is_observer_script_module( + module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" + ) + return False + + +def get_observer_state_dict(mod): + r""" + Returns the state dict corresponding to the observer stats. + Traverse the model state_dict and extract out the stats. + """ + od = OrderedDict() + if isinstance(mod, torch.jit.RecursiveScriptModule): + for k, v in mod.state_dict().items(): + if "observer" in k: + od[k] = v + else: + # path for GraphModule and nn.Module (eager mode) + for k, v in mod.state_dict().items(): + if "activation_post_process" in k: + od[k] = v + od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] + return od + + +def load_observer_state_dict(mod, obs_dict): + r""" + Given input model and a state_dict containing model observer stats, + load the stats back into the model. The observer state_dict can be saved + using torch.ao.quantization.get_observer_state_dict + """ + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + for name, module in mod.named_modules(): + prefix = name + "." + if _is_activation_post_process(module): + if _is_per_channel_script_obs_instance(module): + # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. + # However this is not called when the module is scripted and we end up calling the default one in module.py + module._load_from_state_dict_script( + obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] + ) + else: + module._load_from_state_dict( + obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] + ) + for k in missing_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Missing keys for observer {k} in state_dict" + ) + for k in unexpected_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Unexpected keys for observer {k} in state_dict" + ) + + +# Restrict activations to be in the range (0,127) +default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127) +""" +Default observer for static quantization, usually used for debugging. +""" + +default_placeholder_observer = PlaceholderObserver +""" +Default placeholder observer, usually used for quantization to torch.float16. +""" + +default_debug_observer = RecordingObserver +""" +Default debug-only observer. +""" + +default_weight_observer = MinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric +) +""" +Default weight observer. +""" + +weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127) +""" +Default histogram observer, usually used for PTQ. +""" + +default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +""" +Default per-channel weight observer, usually used on backends where per-channel +weight quantization is supported, such as `fbgemm`. +""" + +per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_dynamic_quant_observer = PlaceholderObserver.with_args( + dtype=torch.quint8, + quant_min=0, + quant_max=255, + is_dynamic=True, +) +""" +Default observer for dynamic quantization. +""" + +default_float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point. +""" + +default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( + dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point and 4 bit activations. +""" + +# TODO(future PR): remove these defaults and enforce activation functions +# to explicitly specify their output range +default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args( + scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255 +) +default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args( + scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255 +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer +default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer + +""" +Default observers for fixed qparams operations. +""" + +default_reuse_input_observer = ReuseInputObserver +""" +Default observer for operators like reshape that reuses the observer of input to +the operator +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98e6956fe2fce5620d72df53a04cf7bdc18f5621 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/qconfig.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/qconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5d1f341751a3b0ea4f720978d3c380e26ccc41 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/qconfig.py @@ -0,0 +1,715 @@ +# mypy: allow-untyped-defs +import copy +import sys +import warnings +from collections import namedtuple +from typing import Any, Optional, Union +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch.ao.quantization.fake_quantize import ( + default_dynamic_fake_quant, + default_embedding_fake_quant, + default_embedding_fake_quant_4bit, + default_fake_quant, + default_fused_act_fake_quant, + default_fused_per_channel_wt_fake_quant, + default_fused_wt_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + FakeQuantize, + FakeQuantizeBase, + fused_per_channel_wt_fake_quant_range_neg_127_to_127, + fused_wt_fake_quant_range_neg_127_to_127, + FusedMovingAvgObsFakeQuantize, +) + +from .observer import ( + _PartialWrapper, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_float_qparams_observer_4bit, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_reuse_input_observer, + default_weight_observer, + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + NoopObserver, + ObserverBase, + per_channel_weight_observer_range_neg_127_to_127, + PlaceholderObserver, + ReuseInputObserver, + weight_observer_range_neg_127_to_127, +) + + +__all__ = [ + "QConfig", + # TODO: deprecated, remove + "QConfigDynamic", + "default_qconfig", + "default_debug_qconfig", + "default_per_channel_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float16_static_qconfig", + "per_channel_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + "float_qparams_weight_only_qconfig_4bit", + "default_quint8_weight_qconfig", + "default_qat_qconfig", + "default_dynamic_qat_qconfig", + "default_weight_only_qconfig", + "default_activation_only_qconfig", + "default_qat_qconfig_v2", + "default_reuse_input_qconfig", + "default_symmetric_qnnpack_qconfig", + "default_per_channel_symmetric_qnnpack_qconfig", + "default_symmetric_qnnpack_qat_qconfig", + "default_per_channel_symmetric_qnnpack_qat_qconfig", + "default_embedding_qat_qconfig", + "default_embedding_qat_qconfig_4bit", + "get_default_qconfig", + "get_default_qat_qconfig", + "get_default_qconfig_dict", + "get_default_qat_qconfig_dict", + "QConfigAny", + "qconfig_equals", +] + + +# pyrefly: ignore [invalid-inheritance] +class QConfig(namedtuple("QConfig", ["activation", "weight"])): + """ + Describes how to quantize a layer or a part of the network by providing + settings (observer classes) for activations and weights respectively. + + + Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization preparation function will instantiate observers multiple times for each of the layers. + + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfig( + activation=MinMaxObserver.with_args(dtype=torch.qint8), + weight=default_observer.with_args(dtype=torch.qint8), + ) + + """ + + __slots__ = () + + def __new__(cls, activation, weight): + # catch common mistakes + if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError( + "QConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +@deprecated( + "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", + category=FutureWarning, +) +# pyrefly: ignore [invalid-inheritance] +class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): + """ + Describes how to dynamically quantize a layer or a part of the network by providing + settings (observer classes) for weights. + + It's like QConfig, but for dynamic quantization. + + Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of the layers. + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) + """ + + __slots__ = () + + def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): + # catch common mistakes + if isinstance(weight, nn.Module): + raise ValueError( + "QConfigDynamic received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +default_qconfig = QConfig(activation=default_observer, weight=default_weight_observer) +""" +Default qconfig configuration. +""" + +default_debug_qconfig = QConfig( + weight=default_weight_observer, activation=default_debug_observer +) +""" +Default qconfig configuration for debugging. +""" + +default_per_channel_qconfig = QConfig( + activation=default_observer, weight=default_per_channel_weight_observer +) +""" +Default qconfig configuration for per channel weight quantization. +""" + +default_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=default_weight_observer +) +""" +Default dynamic qconfig. +""" + +float16_dynamic_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with weights quantized to `torch.float16`. +""" + +float16_static_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with both activations and weights quantized to `torch.float16`. +""" + +per_channel_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, + weight=default_per_channel_weight_observer, +) +""" +Dynamic qconfig with weights quantized per channel. +""" + +float_qparams_weight_only_qconfig = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer +) +""" +Dynamic qconfig with weights quantized with a floating point zero_point. +""" + +float_qparams_weight_only_qconfig_4bit = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer_4bit +) + +default_qat_qconfig = QConfig( + activation=default_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for QAT. +""" + +default_dynamic_qat_qconfig = QConfig( + activation=default_dynamic_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for dynamic QAT. +""" + +default_weight_only_qconfig = QConfig( + activation=torch.nn.Identity, weight=default_weight_fake_quant +) +""" +Default qconfig for quantizing weights only. +""" + +default_activation_only_qconfig = QConfig( + activation=default_fake_quant, weight=torch.nn.Identity +) +""" +Default qconfig for quantizing activations only. +""" + +# QAT config that uses a fused observer + fake quant modules for optimized training performance. +# to modify the activation/weight observers, the default entries in fake_quantize.py can be modified. +default_qat_qconfig_v2 = QConfig( + activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant +) +""" +Fused version of `default_qat_config`, has performance benefits. +""" + +default_reuse_input_qconfig = QConfig( + activation=default_reuse_input_observer, weight=NoopObserver +) +""" +Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape +""" + + +def get_default_qconfig(backend="x86", version=0): + """ + Returns the default PTQ qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_weight_observer, + ) + elif backend == "onednn": + if not torch.cpu._is_vnni_supported(): + warnings.warn( + "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " + "on CPU without Vector Neural Network Instruction support.", + stacklevel=2, + ) + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_per_channel_weight_observer, + ) + elif backend == "x86": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + else: + # won't reach + qconfig = default_qconfig + else: + raise AssertionError( + "Version number: " + + str(version) + + " in get_default_qconfig is not supported. Version number must be 0" + ) + + return qconfig + + +""" +Default, symmetric PTQ qconfig for the specified backend. And a per_channel +variant of the same. + +Symmetric here applies to signed weights with zero point = 0, and additional +value restrictions. The activations are also signed 8-bit integers with this +qconfig. + + * Once this change is merged [as of 3/17/22], with backend or qengine = + 'qnnpack', some quantized operators with this symmetric qconfig may use + operators from xnnpack library. + + ** Support to use xnnpack ops with `qnnpack` backed for asymmetric + qconfig (returned by get_default_qconfig()) is not available yet. + + * This qconfig uses signed activations and weights. Weights have added + restrictions such as zero point is forced to be 0, making the weights + symmetric, hence the name. And the 8-bit quantized values are + restricting to to [-127, +127], excluding -128. + + * xnnpack has a requantization scale value restriction, 0x1p-32 <= + requantization_scale < 256.0 where, `requantization_scale = (input_scale + * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value + of 256) is to prevent requantization_scale to go below xnnpack lower + threshold. +""" +default_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=weight_observer_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=per_channel_weight_observer_range_neg_127_to_127, +) + +default_embedding_qat_qconfig = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant, +) + +default_embedding_qat_qconfig_4bit = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant_4bit, +) + +default_quint8_weight_qconfig = QConfig( + activation=HistogramObserver, weight=MinMaxObserver +) + + +def get_default_qat_qconfig(backend="x86", version=1): + """ + Returns the default QAT qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + * `version`: version, for backwards compatibility. Can be `None` or `1`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + # Histogram observer is too slow for quantization aware training + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "qnnpack": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_weight_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + else: + qconfig = default_qat_qconfig + # Use the fused observe + fake_quant modules for doing QAT. + elif version == 1: + if backend == "fbgemm": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_fused_wt_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + else: + qconfig = default_qat_qconfig_v2 + else: + raise AssertionError( + "Version number: " + + str(version) + + "in get_default_qat_qconfig is not supported. Version number must be 0 or 1" + ) + + return qconfig + + +""" +Default symmetric QAT qconfig for qnnpack. And its per channel weight variant. +""" +default_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_wt_fake_quant_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127, +) + +_default_fp32_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float32), + weight=PlaceholderObserver.with_args(dtype=torch.float32), +) + +_default_quint8_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.quint8), + # operators using this qconfig doesn't have weights + weight=None, +) + + +@deprecated( + "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qconfig_dict(backend="x86", version=0): + return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() + + +@deprecated( + "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qat_qconfig_dict(backend="x86", version=1): + return torch.ao.quantization.get_default_qat_qconfig_mapping( + backend, version + ).to_dict() + + +def _assert_valid_qconfig(qconfig: QConfig | None, mod: torch.nn.Module) -> None: + """ + Verifies that this `qconfig` is valid. + """ + if qconfig is None: + return + is_conv_transpose_mod = isinstance( + mod, + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d), + ) + if is_conv_transpose_mod: + if qconfig.weight is None: + # for now, we assume that any qconfig for ConvTranspose without a weight is valid + return + example_observer = qconfig.weight() + is_per_channel = isinstance( + example_observer, + ( + torch.ao.quantization.PerChannelMinMaxObserver, + torch.ao.quantization.MovingAveragePerChannelMinMaxObserver, + ), + ) + if is_per_channel: + raise AssertionError( + "Per channel weight observer is not supported yet for ConvTranspose{n}d." + ) + + +if sys.version_info < (3, 12): + QConfigAny = Optional[QConfig] + QConfigAny.__module__ = "torch.ao.quantization.qconfig" +else: + from typing import TypeAliasType + + QConfigAny = TypeAliasType("QConfigAny", QConfig | None) + + +def _add_module_to_qconfig_obs_ctr( + qconfig: QConfigAny, module: nn.Module | None +) -> Any: + r"""This is a helper function for use in quantization prepare that updates a qconfig so that + the constructors stored in the qconfig will create observers on the same device that + 'module' is on. This is intended to be used when the qconfigs are propagated to each + module in order to avoid potential device alignment issues. + + Args: + qconfig: QConfig with obs constructors stored in activation and weight + module: module which the qconfig is related to + + Return: + qconfig: configured so that obs constructors set to construct on the same device as module + """ + + if module is None or qconfig is None or qconfig._fields != ("activation", "weight"): + return qconfig + + def get_factory_kwargs_based_on_module_device(): + if not isinstance(module, torch.nn.Module): + raise AssertionError("module must be an instance of torch.nn.Module") + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + device = next(iter(devices)) if len(devices) > 0 else None + return None if device is None else {"device": device} + + def configure_constructor_to_put_obs_on_module_device(original_constructor): + try: + # check if constructor can accept factory_kwargs + check = original_constructor.with_args(factory_kwargs=None) + check() + return original_constructor.with_callable_args( + factory_kwargs=get_factory_kwargs_based_on_module_device + ) + except AttributeError: # qconfig doesn't have activation or weight + return original_constructor + except TypeError: # the class doesn't accept factory_kwargs argument + return original_constructor + + activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation) + weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight) + + return QConfig(activation, weight) + + +_ObserverOrFakeQuantizeConstructor = Union[ + _PartialWrapper, type[ObserverBase], type[FakeQuantizeBase] +] + + +def _obs_or_fq_ctr_equals( + obs_or_fq1: _ObserverOrFakeQuantizeConstructor, + obs_or_fq2: _ObserverOrFakeQuantizeConstructor, +): + if isinstance(obs_or_fq1, _PartialWrapper) and isinstance( + obs_or_fq2, _PartialWrapper + ): + return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) + return obs_or_fq1 == obs_or_fq2 + + +def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper): + """ + Return whether the two partial wrappers are equal, + """ + # functools.partial has no __eq__ operator defined so '==' defaults to 'is' + obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) + obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) + keywords_equal = True + # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail + if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: + keywords_equal = keywords_equal and _obs_or_fq_ctr_equals( + obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"] + ) + obs_or_fq1_keywords.pop("observer") + obs_or_fq2_keywords.pop("observer") + keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords + return ( + obs_or_fq1.p.func == obs_or_fq2.p.func + and obs_or_fq1.p.args == obs_or_fq2.p.args + and keywords_equal + ) + + +def qconfig_equals(q1: QConfigAny, q2: QConfigAny): + """ + Returns `True` if `q1` equals `q2`, and `False` otherwise. + """ + if q1 is None or q2 is None: + return q1 == q2 + else: + if q1 is None or q2 is None: + raise AssertionError( + "Both q1 and q2 must be non-None for qconfig comparison" + ) + try: + # Qconfig weight and activation can be either a partial wrapper, + # or an observer class. Special handling is required (above) for + # comparing partial wrappers. + activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation) + weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight) + return activation_same and weight_same + except AttributeError: + return q1 == q2 + + +def _activation_is_memoryless(qconfig: QConfig): + """ + Return whether the observer for activations defined in the given QConfig is memoryless. + This means a MovingAverage observer with averaging constant equal to 1. + """ + + def _is_memoryless(observer): + return ( + hasattr(observer, "averaging_constant") and observer.averaging_constant == 1 + ) + + act = qconfig.activation() + if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"): + return _is_memoryless(act.activation_post_process) + else: + return _is_memoryless(act) + + +def _is_reuse_input_qconfig(qconfig: QConfig | None): + return ( + qconfig is not None + and isinstance(qconfig.activation(), ReuseInputObserver) + and isinstance(qconfig.weight(), NoopObserver) + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/qconfig_mapping.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/qconfig_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..cf896a96da055ea99d1e165c12dc450f50ad77dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/qconfig_mapping.py @@ -0,0 +1,385 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from collections import OrderedDict +from typing import Any, TYPE_CHECKING + +import torch + +from .fake_quantize import default_weight_fake_quant, FixedQParamsFakeQuantize +from .observer import ( + _PartialWrapper, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + default_placeholder_observer, + default_weight_observer, +) +from .qconfig import ( + default_quint8_weight_qconfig, + default_reuse_input_qconfig, + default_symmetric_qnnpack_qat_qconfig, + default_symmetric_qnnpack_qconfig, + get_default_qat_qconfig, + get_default_qconfig, + QConfig, + QConfigAny, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + + +__all__ = [ + "get_default_qconfig_mapping", + "get_default_qat_qconfig_mapping", + "QConfigMapping", +] + + +# TODO: replace all usages with these constants +_GLOBAL_DICT_KEY = "" +_OBJECT_TYPE_DICT_KEY = "object_type" +_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" +_MODULE_NAME_DICT_KEY = "module_name" +_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" + +# TODO: derive this map from the BackendConfig +_FIXED_QPARAMS_OP_TO_OBSERVER: dict[Callable | str, _PartialWrapper] = { + torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, + torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer, + "hardsigmoid": default_fixed_qparams_range_0to1_observer, + "hardsigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer, + torch.sigmoid: default_fixed_qparams_range_0to1_observer, + "sigmoid": default_fixed_qparams_range_0to1_observer, + "sigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Softmax: default_fixed_qparams_range_0to1_observer, + torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer, + torch.tanh: default_fixed_qparams_range_neg1to1_observer, + "tanh": default_fixed_qparams_range_neg1to1_observer, + "tanh_": default_fixed_qparams_range_neg1to1_observer, +} + + +def _get_default_qconfig_mapping( + is_qat: bool, backend: str, version: int +) -> QConfigMapping: + """ + Return the default QConfigMapping for the given quantization type and backend. + """ + if is_qat: + qconfig = get_default_qat_qconfig(backend, version) + else: + qconfig = get_default_qconfig(backend, version) + default_weight = default_weight_fake_quant if is_qat else default_weight_observer + + # default_per_channel_weight_observer is not currently compatible with fbgemm backend + # so we have to modify the weight observer to default_weight_observer or another + # per tensor supported observer. + # see https://github.com/pytorch/pytorch/issues/47535 + if backend in ("fbgemm", "x86"): + qconfig_transpose = QConfig( + activation=qconfig.activation, weight=default_weight + ) + else: + qconfig_transpose = qconfig + + # currently layernorm only supports float weights + # we have to add this because otherwise there will be a extra quantize-dequantize pair + qconfig_layernorm = QConfig( + activation=qconfig.activation, weight=default_placeholder_observer + ) + + qconfig_mapping = ( + QConfigMapping() + .set_global(qconfig) + .set_object_type("reshape", default_reuse_input_qconfig) + .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) + .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) + .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) + .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) + .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) + .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) + ) + # Use special observers for ops with fixed qparams + fixed_qparams_observer_to_qconfig: dict[Any, QConfigAny] = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer] + else: + if is_qat: + activation = FixedQParamsFakeQuantize.with_args(observer=observer) + else: + activation = observer + fixed_qparams_qconfig = QConfig( + activation=activation, weight=default_weight + ) + fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig + qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig) + + # TODO Currently it's required that separate ops in a fused op/module have the same qconfig. + # Need to be able to support fusion of ops with different qconfigs + + return qconfig_mapping + + +def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping: + """ + Return the default QConfigMapping for post training quantization. + + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + # TODO: add assert for backend choices + return _get_default_qconfig_mapping(False, backend, version) + + +def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping: + """ + Return the default QConfigMapping for quantization aware training. + + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + return _get_default_qconfig_mapping(True, backend, version) + + +def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qconfig + return _get_default_qconfig_mapping_with_default_qconfig( + False, "qnnpack", default_qconfig + ) + + +def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qat_qconfig + return _get_default_qconfig_mapping_with_default_qconfig( + True, "qnnpack", default_qconfig + ) + + +def _get_default_qconfig_mapping_with_default_qconfig( + is_qat: bool, + backend: str, + default_qconfig: QConfig, +) -> QConfigMapping: + """ + Return a QConfigMapping that uses the provided qconfig as the default QConfig. + """ + if is_qat: + qconfig_mapping = get_default_qat_qconfig_mapping(backend) + else: + qconfig_mapping = get_default_qconfig_mapping(backend) + qconfig_mapping.set_global(default_qconfig) + for pattern in qconfig_mapping.object_type_qconfigs: + if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER: + qconfig_mapping.set_object_type(pattern, default_qconfig) + return qconfig_mapping + + +_QCONFIG_STYLE_ORDER: list[str] = [ + "global_qconfig", + "object_type_qconfigs", + "module_name_regex_qconfigs", + "module_name_qconfigs", + "module_name_object_type_order_qconfigs", +] + + +class QConfigMapping: + """ + Mapping from model ops to :class:`torch.ao.quantization.QConfig` s. + + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfig + + ``set_object_type`` : sets the QConfig for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string + + ``set_module_name`` : sets the QConfig for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Example usage:: + + qconfig_mapping = QConfigMapping() + .set_global(global_qconfig) + .set_object_type(torch.nn.Linear, qconfig1) + .set_object_type(torch.nn.ReLU, qconfig1) + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*", qconfig2) + .set_module_name("module1", qconfig1) + .set_module_name("module2", qconfig2) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3) + + """ + + def __init__(self) -> None: + # In increasing match priority: + self.global_qconfig: QConfigAny = None + self.object_type_qconfigs: OrderedDict[Callable | str, QConfigAny] = ( + OrderedDict() + ) + self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() + self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() + self.module_name_object_type_order_qconfigs: OrderedDict[ + tuple[str, Callable, int], QConfigAny + ] = OrderedDict() + + def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping: + """ + Set the global (default) QConfig. + """ + self.global_qconfig = global_qconfig + return self + + def set_object_type( + self, object_type: Callable | str, qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for a given module type, function, or method name. + If the QConfig for an existing object type was already set, the new QConfig will override the old one. + """ + self.object_type_qconfigs[object_type] = qconfig + return self + + def set_module_name_regex( + self, module_name_regex: str, qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for modules matching the given regex string. + + Regexes will be matched in the order in which they are registered through this method. + Thus, the caller should register more specific patterns first, e.g.:: + + qconfig_mapping = QConfigMapping() + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*bar.*", qconfig2) + .set_module_name_regex("foo.*", qconfig3) + + In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2, + and "foo.baz.relu" would match qconfig3. + + If the QConfig for an existing module name regex was already set, the new QConfig will override the + old one while preserving the order in which the regexes were originally registered. + """ + self.module_name_regex_qconfigs[module_name_regex] = qconfig + return self + + def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping: + """ + Set the QConfig for modules matching the given module name. + If the QConfig for an existing module name was already set, the new QConfig will override the old one. + """ + self.module_name_qconfigs[module_name] = qconfig + return self + + def set_module_name_object_type_order( + self, module_name: str, object_type: Callable, index: int, qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for modules matching a combination of the given module name, object type, + and the index at which the module appears. + + If the QConfig for an existing (module name, object type, index) was already set, the new QConfig + will override the old one. + """ + self.module_name_object_type_order_qconfigs[ + (module_name, object_type, index) + ] = qconfig + return self + + def __repr__(self) -> str: + output = self.__class__.__name__ + " (" + for style_name in _QCONFIG_STYLE_ORDER: + output += f"\n {style_name}" + qconfigs = getattr(self, style_name) + if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0: + for key, qconfig in qconfigs.items(): + output += f"\n {key}: {qconfig}" + else: + output += f"\n {qconfigs}" + return output + "\n)" + + # TODO: remove this + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``QConfigMapping`` to a dictionary with the following keys: + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are lists of tuples. + """ + return { + _GLOBAL_DICT_KEY: self.global_qconfig, + _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), + _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), + _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items() + ], + } + + # TODO: remove this + @classmethod + def from_dict(cls, qconfig_dict: dict[str, Any]) -> QConfigMapping: + """ + Create a ``QConfigMapping`` from a dictionary with the following keys (all optional): + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are expected to be lists of tuples. + """ + conf = cls() + if _GLOBAL_DICT_KEY in qconfig_dict: + conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY]) + for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []): + conf.set_object_type(object_type, qconfig) + for module_name_regex, qconfig in qconfig_dict.get( + _MODULE_NAME_REGEX_DICT_KEY, [] + ): + conf.set_module_name_regex(module_name_regex, qconfig) + for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []): + conf.set_module_name(module_name, qconfig) + for module_name, object_type, index, qconfig in qconfig_dict.get( + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, [] + ): + conf.set_module_name_object_type_order( + module_name, object_type, index, qconfig + ) + return conf diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quant_type.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quant_type.py new file mode 100644 index 0000000000000000000000000000000000000000..18488d7f9ccba604ca8f1df7ea0ef4a88546d63e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quant_type.py @@ -0,0 +1,35 @@ +import enum + + +__all__ = [ + "QuantType", +] + + +# Quantization type (dynamic quantization, static quantization). +# Should match the c++ enum in quantization_type.h +class QuantType(enum.IntEnum): + DYNAMIC = 0 + STATIC = 1 + QAT = 2 + WEIGHT_ONLY = 3 + + +_quant_type_to_str = { + QuantType.STATIC: "static", + QuantType.DYNAMIC: "dynamic", + QuantType.QAT: "qat", + QuantType.WEIGHT_ONLY: "weight_only", +} + + +# TODO: make this private +def _get_quant_type_to_str(quant_type: QuantType) -> str: + return _quant_type_to_str[quant_type] + + +def _quant_type_from_str(name: str) -> QuantType: + for quant_type, s in _quant_type_to_str.items(): + if name == s: + return quant_type + raise ValueError(f"Unknown QuantType name '{name}'") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantization_mappings.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantization_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..647ed5a4d4f3946626ef360a7a45541719136006 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantization_mappings.py @@ -0,0 +1,369 @@ +import copy +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.nn as ao_nn +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.qat as nnqat +import torch.ao.nn.qat.dynamic as nnqatd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized.reference as nnqr + +# Because `torch.ao.nn` uses lazy imports, we need to make +# sure we import the contents explicitly here. +import torch.ao.nn.sparse +import torch.nn.functional as F +from torch import nn +from torch.ao.quantization.fake_quantize import ( + default_fixed_qparams_range_0to1_fake_quant, + default_fixed_qparams_range_neg1to1_fake_quant, +) +from torch.ao.quantization.stubs import DeQuantStub, QuantStub +from torch.ao.quantization.utils import get_combined_dict +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_QAT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS", + "DEFAULT_MODULE_TO_ACT_POST_PROCESS", + "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS", + "no_observer_set", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_embedding_static_quant_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_static_quant_module_class", + "get_dynamic_quant_module_class", + "get_default_qat_module_mappings", + "get_embedding_qat_module_mappings", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_default_float_to_quantized_operator_mappings", + "get_quantized_operator", +] + +# Default map for swapping float module to reference quantized modules +DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.Linear: nnqr.Linear, + nn.Conv1d: nnqr.Conv1d, + nn.Conv2d: nnqr.Conv2d, + nn.Conv3d: nnqr.Conv3d, + nn.ConvTranspose1d: nnqr.ConvTranspose1d, + nn.ConvTranspose2d: nnqr.ConvTranspose2d, + nn.ConvTranspose3d: nnqr.ConvTranspose3d, + nn.Embedding: nnqr.Embedding, + nn.EmbeddingBag: nnqr.EmbeddingBag, + nn.GRUCell: nnqr.GRUCell, + nn.LSTMCell: nnqr.LSTMCell, + nn.RNNCell: nnqr.RNNCell, + nn.LSTM: nnqr.LSTM, +} + +# Default map for swapping float module to quantized ones +DEFAULT_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nn.Dropout: nnq.Dropout, + nn.Conv1d: nnq.Conv1d, + nn.Conv2d: nnq.Conv2d, + nn.Conv3d: nnq.Conv3d, + nn.ConvTranspose1d: nnq.ConvTranspose1d, + nn.ConvTranspose2d: nnq.ConvTranspose2d, + nn.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.Embedding: nnq.Embedding, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.GroupNorm: nnq.GroupNorm, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.LeakyReLU: nnq.LeakyReLU, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear, + nn.Linear: nnq.Linear, + nn.ReLU6: nnq.ReLU6, + nn.PReLU: nnq.PReLU, + # Wrapper Modules: + nnq.FloatFunctional: nnq.QFunctional, + # Intrinsic modules: + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, + nni.ConvReLU1d: nniq.ConvReLU1d, + nni.ConvReLU2d: nniq.ConvReLU2d, + nni.ConvReLU3d: nniq.ConvReLU3d, + nni.ConvAdd2d: nniq.ConvAdd2d, + nni.ConvAddReLU2d: nniq.ConvAddReLU2d, + nni.LinearReLU: nniq.LinearReLU, + nni.LinearLeakyReLU: nniq.LinearLeakyReLU, + nni.LinearTanh: nniq.LinearTanh, + nniqat.ConvBn1d: nnq.Conv1d, + nniqat.ConvBn2d: nnq.Conv2d, + nniqat.ConvBn3d: nnq.Conv3d, + nniqat.ConvBnReLU1d: nniq.ConvReLU1d, + nniqat.ConvBnReLU2d: nniq.ConvReLU2d, + nniqat.ConvBnReLU3d: nniq.ConvReLU3d, + nniqat.ConvReLU2d: nniq.ConvReLU2d, + nniqat.ConvReLU3d: nniq.ConvReLU3d, + nniqat.LinearReLU: nniq.LinearReLU, + nniqat.LinearBn1d: nnq.Linear, + # QAT modules: + nnqat.Linear: nnq.Linear, + nnqat.Conv2d: nnq.Conv2d, + nnqat.Conv3d: nnq.Conv3d, +} + +# Default map for swapping float module to qat modules +DEFAULT_QAT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Conv2d: nnqat.Conv2d, + nn.Conv3d: nnqat.Conv3d, + nn.Linear: nnqat.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear, + # Intrinsic modules: + nni.ConvBn1d: nniqat.ConvBn1d, + nni.ConvBn2d: nniqat.ConvBn2d, + nni.ConvBn3d: nniqat.ConvBn3d, + nni.ConvBnReLU1d: nniqat.ConvBnReLU1d, + nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, + nni.ConvBnReLU3d: nniqat.ConvBnReLU3d, + nni.ConvReLU2d: nniqat.ConvReLU2d, + nni.ConvReLU3d: nniqat.ConvReLU3d, + nni.LinearReLU: nniqat.LinearReLU, + nni.LinearBn1d: nniqat.LinearBn1d, +} + +# Default map for swapping dynamic modules +DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.GRUCell: nnqd.GRUCell, + nn.Linear: nnqd.Linear, + nnqatd.Linear: nnqd.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear, + nn.LSTM: nnqd.LSTM, + nn.GRU: nnqd.GRU, + nn.LSTMCell: nnqd.LSTMCell, + nn.RNNCell: nnqd.RNNCell, + nni.LinearReLU: nniqd.LinearReLU, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.Embedding: nnq.Embedding, + # Don't want to enable these by default because the numerical + # accuracy is poor compared to other dynamic ops + # nn.Conv1d: nnqd.Conv1d, + # nn.Conv2d: nnqd.Conv2d, + # nn.Conv3d: nnqd.Conv3d, + # nn.ConvTranspose1d: nnqd.ConvTranspose1d, + # nn.ConvTranspose2d: nnqd.ConvTranspose2d, + # nn.ConvTranspose3d: nnqd.ConvTranspose3d, +} + +# Allowlist for propagating the qconfig +_INCLUDE_QCONFIG_PROPAGATE_LIST: set[Callable] = { + nn.Sequential, +} + +# Default mapping from floating point function or torch ops to quantized ops +# TODO: merge with default static mapping +DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS: dict[Callable | str, Callable] = { + F.elu: torch.ops.quantized.elu, + F.hardswish: torch.ops.quantized.hardswish, + F.instance_norm: torch.ops.quantized.instance_norm, + F.layer_norm: torch.ops.quantized.layer_norm, + F.leaky_relu: torch.ops.quantized.leaky_relu, + F.dropout: torch.ops.quantized.dropout, +} + +# mapping from module to output activation post process class +DEFAULT_MODULE_TO_ACT_POST_PROCESS: dict[Callable, Callable] = { + nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Softmax: default_fixed_qparams_range_0to1_fake_quant, + nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant, +} + +# Default map for swapping float module to static sparse quantized ones +DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.Linear +} + +# Default map for swapping float module to dynamic sparse quantized ones +DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.dynamic.Linear +} + + +def no_observer_set() -> set[Any]: + r"""These modules cannot have observers inserted by default.""" + no_observers = {nn.quantizable.LSTM, nn.quantizable.MultiheadAttention} + return no_observers + + +def get_default_static_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training static quantization""" + return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + + +def get_default_static_quant_reference_module_mappings() -> dict[Callable, Any]: + """Get reference module mapping for post training static quantization""" + return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS) + + +def get_embedding_static_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping, including mapping for embedding QAT""" + mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag + mapping[nnqat.Embedding] = nnq.Embedding + return mapping + + +def get_default_static_sparse_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training static sparse quantization""" + return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS) + + +def get_static_quant_module_class( + float_module_class: Callable, + additional_static_quant_mapping: dict[Callable, Any] | None = None, + is_reference: bool = False, +) -> Any: + r"""n Get the statically quantized module class corresponding to + the floating point module class + """ + if additional_static_quant_mapping is None: + additional_static_quant_mapping = {} + all_mappings = get_combined_dict( + DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS + if is_reference + else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, + additional_static_quant_mapping, + ) + static_quant_module_class = all_mappings.get(float_module_class, None) + if static_quant_module_class is None: + raise AssertionError( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) + return copy.deepcopy(static_quant_module_class) + + +def get_dynamic_quant_module_class( + float_module_class: Callable, + additional_dynamic_quant_mapping: dict[Callable, Any] | None = None, +) -> Any: + r"""n Get the dynamically quantized module class corresponding to + the floating point module class + """ + if additional_dynamic_quant_mapping is None: + additional_dynamic_quant_mapping = {} + all_mappings = get_combined_dict( + DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping + ) + dynamic_quant_module_class = all_mappings.get(float_module_class, None) + if dynamic_quant_module_class is None: + raise AssertionError( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) + return copy.deepcopy(dynamic_quant_module_class) + + +def get_default_qat_module_mappings() -> dict[Callable, Any]: + """Get default module mapping for quantization aware training""" + return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + + +def get_embedding_qat_module_mappings() -> dict[Callable, Any]: + """Get module mapping for quantization aware training + This is includes default values in addition to + enabling qat for embeddings. + """ + mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag + mapping[nn.Embedding] = nnqat.Embedding + return mapping + + +def get_default_dynamic_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training dynamic quantization""" + return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS + + +def get_default_dynamic_sparse_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training dynamic sparse quantization""" + return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS + + +def get_default_qconfig_propagation_list() -> set[Callable]: + """Get the default list of module types that we'll attach qconfig + attribute to in prepare + """ + QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST) + + +def get_default_compare_output_module_list() -> set[Callable]: + """Get list of module class types that we will record output + in numeric suite + """ + NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.values()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST) + + +def get_default_float_to_quantized_operator_mappings() -> dict[ + Callable | str, Callable +]: + return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS) + + +# TODO: merge with get_static_quant_module_class +def get_quantized_operator(float_op: Callable | str) -> Callable: + """Get the quantized operator corresponding to the float operator""" + quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op) + if quantized_op is None: + raise AssertionError( + f"Operator {str(float_op)} does not have corresponding quantized op" + ) + return quantized_op + + +def _get_special_act_post_process(module: torch.nn.Module) -> Callable | None: + r"""Get the special activation post process for `module`, this has + higher priority than the activation post process in `qconfig` + e.g. + input: torch.nn.Sigmoid + output: default_affine_fixed_qparam_fake_quant + """ + return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get( + type_before_parametrizations(module), None + ) + + +def _has_special_act_post_process(module: torch.nn.Module) -> bool: + return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..e71dd24fda745d7f23f671eedaa1ff43df147a9a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize.py @@ -0,0 +1,829 @@ +# mypy: allow-untyped-defs +import copy +import inspect +import itertools +import typing_extensions +import warnings + +import torch +import torch.ao.nn.quantized as nnq +import torch.nn as nn +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import ( + _activation_is_memoryless, + _add_module_to_qconfig_obs_ctr, + default_dynamic_qconfig, + float16_dynamic_qconfig, + float_qparams_weight_only_qconfig, + float_qparams_weight_only_qconfig_4bit, +) +from torch.ao.quantization.quantization_mappings import ( + _get_special_act_post_process, + _has_special_act_post_process, + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + get_default_static_quant_module_mappings, + get_default_static_quant_reference_module_mappings, + no_observer_set, +) +from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + DEPRECATION_WARNING, + get_qparam_dict, + has_no_children_ignoring_parametrizations, +) + + +__all__ = [ + "get_default_custom_config_dict", + "propagate_qconfig_", + "add_quant_dequant", + "prepare", + "quantize", + "quantize_dynamic", + "prepare_qat", + "quantize_qat", + "convert", + "swap_module", +] + + +# TODO remove this once BC is no longer required to avoid a SEV +is_activation_post_process = _is_activation_post_process + + +_DEFAULT_CUSTOM_CONFIG_DICT = { + "float_to_observed_custom_module_class": { + nn.LSTM: nn.quantizable.LSTM, + nn.MultiheadAttention: nn.quantizable.MultiheadAttention, + }, + "observed_to_quantized_custom_module_class": { + nn.quantizable.LSTM: nn.quantized.LSTM, + nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention, + }, +} + + +def get_default_custom_config_dict(): + r"""Defines the default custom config dict.""" + return _DEFAULT_CUSTOM_CONFIG_DICT + + +def _propagate_qconfig_helper( + module, + qconfig_dict, + qconfig_parent=None, + prefix="", + prepare_custom_config_dict=None, +): + r"""This is a helper function for `propagate_qconfig_` + + Args: + module: input module + qconfig_dict: dictionary that maps from name of submodule to quantization + configuration + qconfig_parent: quantization config of parent module, we will fallback to + this config when there is no specified config for current + module + prefix: corresponding prefix of the current module, used as key in + qconfig_dict + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + + module_qconfig = qconfig_dict.get( + type_before_parametrizations(module), qconfig_parent + ) + module_qconfig = qconfig_dict.get(prefix, module_qconfig) + module_qconfig = getattr(module, "qconfig", module_qconfig) + + torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module) + + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module) + module.qconfig = qconfig_with_device_check + + for name, child in module.named_children(): + module_prefix = prefix + "." + name if prefix else name + # do no not propagate qconfig to child if child is non traceable + if prepare_custom_config_dict is None or not ( + name in prepare_custom_config_dict.get("non_traceable_module_name", []) + or type(child) + in prepare_custom_config_dict.get("non_traceable_module_class", []) + ): + _propagate_qconfig_helper( + child, qconfig_dict, qconfig_with_device_check, module_prefix + ) + + +def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None): + r"""Propagate qconfig through the module hierarchy and assign `qconfig` + attribute on each leaf module + + Args: + module: input module + qconfig_dict: dictionary that maps from name or type of submodule to + quantization configuration, qconfig applies to all submodules of a + given module unless qconfig for the submodules are specified (when + the submodule already has qconfig attribute) + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + if qconfig_dict is None: + qconfig_dict = {} + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + _propagate_qconfig_helper( + module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict + ) + + +def _observer_forward_hook(self, input, output): + r"""Forward hook that calls observer on the output""" + return self.activation_post_process(output) + + +def _observer_forward_pre_hook(self, input): + r"""Forward pre hook that calls observer on the output""" + return self.activation_post_process(input[0]) + + +def _register_activation_post_process_hook(module, pre_hook=False): + if not hasattr(module, "activation_post_process"): + raise AssertionError( + "Expect activation_post_process attribute already attached to the module" + ) + if pre_hook: + module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True) + else: + module.register_forward_hook(_observer_forward_hook, prepend=True) + + +def _add_observer_( + module, + qconfig_propagation_list=None, + non_leaf_module_list=None, + device=None, + custom_module_class_mapping=None, +): + r"""Add observer for the leaf child of the module. + + This function insert observer module to all leaf child module that + has a valid qconfig attribute. + + Args: + module: input module with qconfig attributes for all the leaf modules that we want to quantize + qconfig_propagation_list: a list of quantizable modules that will have observers added to them + if they are leaf nodes + device: parent device, if any + non_leaf_module_list: list of non-leaf modules we want to add observer + + Return: + None, module is modified inplace with added observer modules and forward_hooks + """ + if qconfig_propagation_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + + if custom_module_class_mapping is None: + custom_module_class_mapping = {} + + # respect device affinity when adding observers + if device is None: + devices = _get_unique_devices_(module) + if len(devices) > 1: + raise AssertionError( + f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + + def get_activation_post_process(qconfig, device, special_act_post_process=None): + activation = ( + qconfig.activation() + if special_act_post_process is None + else special_act_post_process() + ) + if device is not None: + activation.to(device) + return activation + + def needs_observation(m): + return hasattr(m, "qconfig") and m.qconfig is not None + + def insert_activation_post_process(m, special_act_post_process=None): + """Adds an activation post process module and register + a pre or post hook that calls the module + """ + # We don't insert observer/fake_quantize for DeQuantStub + if needs_observation(m) and not isinstance(m, DeQuantStub): + # observer and hook will be gone after we swap the module + m.add_module( + "activation_post_process", + get_activation_post_process( + m.qconfig, device, special_act_post_process + ), + ) + # Register observer as the first entry in the hook list + # All post forward hooks are preserved and will be executed after the observer before convert + _register_activation_post_process_hook( + m, pre_hook=_activation_is_memoryless(m.qconfig) + ) + + for name, child in module.named_children(): + # TODO remove Dropout special after codebase stable + if type_before_parametrizations(child) is nn.Dropout: + continue + elif issubclass( + type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional) + ): + if needs_observation(child): + if not hasattr(child, "activation_post_process"): + raise AssertionError( + f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" + ) + child.activation_post_process = get_activation_post_process( + child.qconfig, device + ) + elif isinstance(child, _FusedModule): + # activation_post_process are now added directly to nn.Sequential/_FusedModule + if needs_observation(child): + insert_activation_post_process(child) + elif ( + non_leaf_module_list is not None + and type_before_parametrizations(child) in non_leaf_module_list + ): + if needs_observation(child): + insert_activation_post_process(child) + elif _has_special_act_post_process(child): + special_act_post_process = _get_special_act_post_process(child) + insert_activation_post_process(child, special_act_post_process) + elif ( + needs_observation(child) + and type_before_parametrizations(child) in custom_module_class_mapping + ): + observed_class = custom_module_class_mapping[ + type_before_parametrizations(child) + ] + observed_child = observed_class.from_float(child) + setattr(module, name, observed_child) + # TODO: These are the modules that cannot be observed + # Once there are more, we should move them to a separate list + if not issubclass(observed_class, tuple(no_observer_set())): + insert_activation_post_process(observed_child) + else: + _add_observer_( + child, + qconfig_propagation_list, + non_leaf_module_list, + device, + custom_module_class_mapping, + ) + + # Insert observers only for leaf nodes, note that this observer is for + # the output of the module, for input QuantStub will observe them + if ( + has_no_children_ignoring_parametrizations(module) + and not isinstance(module, torch.nn.Sequential) + and type_before_parametrizations(module) in qconfig_propagation_list + ): + insert_activation_post_process(module) + # This is a special case for AdaRound eager mode + # AdaRound contains weight_fake_quant to be propagated from API to convert + # leaf node check with a number of children looks naive assumption that blocks + # Adding an exception case for AdaRound + if ( + hasattr(module, "weight_fake_quant") + and not isinstance(module, torch.nn.Sequential) + and type_before_parametrizations(module) in qconfig_propagation_list + ): + insert_activation_post_process(module) + + +def _get_unique_devices_(module): + return {p.device for p in module.parameters() if p.device.type != "meta"} | { + p.device for p in module.buffers() if p.device.type != "meta" + } + + +def add_quant_dequant(module): + r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig + Note that this function will modify the children of module inplace and it + can return a new module which wraps the input module as well. + + Args: + module: input module with qconfig attributes for all the leaf modules + that we want to quantize + + Return: + Either the inplace modified module with submodules wrapped in + `QuantWrapper` based on qconfig or a new `QuantWrapper` module which + wraps the input module, the latter case only happens when the input + module is a leaf module and we want to quantize it. + """ + if ( + has_no_children_ignoring_parametrizations(module) + and hasattr(module, "qconfig") + and module.qconfig + ): + return QuantWrapper(module) + + for name, child in module.named_children(): + module._modules[name] = add_quant_dequant(child) + return module + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare( + model, + inplace=False, + allow_list=None, + observer_non_leaf_module_list=None, + prepare_custom_config_dict=None, +): + r"""Prepares a copy of the model for quantization calibration or quantization-aware training. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + The model will be attached with observer or fake quant modules, and qconfig + will be propagated. + + Args: + `model`: input model to be modified in-place + `inplace`: carry out model transformations in-place, the original module is mutated + `allow_list`: list of quantizable modules + `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer + `prepare_custom_config_dict`: customization configuration dictionary for prepare function + + .. code-block:: python + + # Example of prepare_custom_config_dict: + prepare_custom_config_dict = { + # user will manually define the corresponding observed + # module class which has a from_float class method that converts + # float custom module to observed custom module + "float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule} + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare") + if prepare_custom_config_dict is None: + prepare_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {} + ) + + if not inplace: + model = copy.deepcopy(model) + + # TODO: remove allow_list + qconfig_propagation_list = allow_list + if allow_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + propagate_qconfig_(model, qconfig_dict=None) + + # sanity check common API misusage + if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()): + warnings.warn( + "None of the submodule got qconfig applied. Make sure you " + "passed correct configuration through `qconfig_dict` or " + "by assigning the `.qconfig` attribute directly on submodules", + stacklevel=2, + ) + + _add_observer_( + model, + qconfig_propagation_list, + observer_non_leaf_module_list, + custom_module_class_mapping=custom_module_class_mapping, + ) + return model + + +def _remove_activation_post_process(module): + # TODO: maybe we should change activation_post_process to _activation_post_process + # to prevent it from being used by user + if hasattr(module, "activation_post_process") and _is_activation_post_process( + module.activation_post_process + ): + delattr(module, "activation_post_process") + + # remove activation_post_process pre and post hooks + def remove_hooks(pre_hook=False): + hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks + observer_hook = ( + _observer_forward_pre_hook if pre_hook else _observer_forward_hook + ) + handle_ids_to_remove = set() + for handle_id, hook_fn in hook_map.items(): + if hook_fn is observer_hook: + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + hook_map.pop(handle_id) + + remove_hooks(pre_hook=True) + remove_hooks(pre_hook=False) + + +# TODO: rename to something more general +def _remove_qconfig(module): + r"""Clean up the qconfig left in the module so that new qconfig can be + propagated. + + Args: + module: module to be cleaned up + """ + for child in module.children(): + _remove_qconfig(child) + + if hasattr(module, "qconfig"): + del module.qconfig + + _remove_activation_post_process(module) + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize(model, run_fn, run_args, mapping=None, inplace=False): + r"""Quantize the input float model with post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. + + Args: + model: input float model + run_fn: a calibration function for calibrating the prepared model + run_args: positional arguments for `run_fn` + inplace: carry out model transformations in-place, the original module is mutated + mapping: correspondence between original module types and quantized counterparts + + Return: + Quantized model. + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize") + if mapping is None: + mapping = get_default_static_quant_module_mappings() + if not inplace: + model = copy.deepcopy(model) + model.eval() + prepare(model, inplace=True) + run_fn(model, *run_args) + convert(model, mapping, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize_dynamic( + model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False +): + r"""Converts a float model to dynamic (i.e. weights-only) quantized model. + + Replaces specified modules with dynamic weight-only quantized versions and output the quantized model. + + For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization + by default is performed for layers with large weights size - i.e. Linear and RNN variants. + + Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`. + If `qconfig` is provided, the `dtype` argument is ignored. + + Args: + model: input model + qconfig_spec: Either: + + - A dictionary that maps from name or type of submodule to quantization + configuration, qconfig applies to all submodules of a given + module unless qconfig for the submodules are specified (when the + submodule already has qconfig attribute). Entries in the dictionary + need to be QConfig instances. + + - A set of types and/or submodule names to apply dynamic quantization to, + in which case the `dtype` argument is used to specify the bit-width + + inplace: carry out model transformations in-place, the original module is mutated + mapping: maps type of a submodule to a type of corresponding dynamically quantized version + with which the submodule needs to be replaced + + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic") + if qconfig_spec is None: + if dtype == torch.qint8: + qconfig_spec = { + nn.Linear: default_dynamic_qconfig, + nn.LSTM: default_dynamic_qconfig, + nn.GRU: default_dynamic_qconfig, + nn.LSTMCell: default_dynamic_qconfig, + nn.RNNCell: default_dynamic_qconfig, + nn.GRUCell: default_dynamic_qconfig, + } + elif dtype == torch.float16: + qconfig_spec = { + nn.Linear: float16_dynamic_qconfig, + nn.LSTM: float16_dynamic_qconfig, + nn.GRU: float16_dynamic_qconfig, + nn.LSTMCell: float16_dynamic_qconfig, + nn.RNNCell: float16_dynamic_qconfig, + nn.GRUCell: float16_dynamic_qconfig, + } + elif dtype == torch.quint8: + qconfig_spec = { + nn.EmbeddingBag: float_qparams_weight_only_qconfig, + nn.Embedding: float_qparams_weight_only_qconfig, + } + elif dtype == torch.quint4x2: + qconfig_spec = { + nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit, + } + else: + raise ValueError( + f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please" + ) + elif isinstance(qconfig_spec, set): + if dtype is torch.qint8: + default_qconfig = default_dynamic_qconfig + elif dtype is torch.float16: + default_qconfig = float16_dynamic_qconfig + elif dtype is torch.quint8: + default_qconfig = float_qparams_weight_only_qconfig + elif dtype is torch.quint4x2: + default_qconfig = float_qparams_weight_only_qconfig_4bit + else: + raise RuntimeError( + "Unknown dtype specified for quantize_dynamic: ", str(dtype) + ) + qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) + + if mapping is None: + mapping = get_default_dynamic_quant_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + model.eval() + propagate_qconfig_(model, qconfig_spec) + convert(model, mapping, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat(model, mapping=None, inplace=False): + r""" + Prepares a copy of the model for quantization calibration or + quantization-aware training and converts it to quantized version. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + Args: + model: input model to be modified in-place + mapping: dictionary that maps float modules to quantized modules to be + replaced. + inplace: carry out model transformations in-place, the original module + is mutated + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") + if not model.training: + raise AssertionError("prepare_qat only works on models in training mode") + if mapping is None: + mapping = get_default_qat_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + + propagate_qconfig_(model, qconfig_dict=None) + convert(model, mapping=mapping, inplace=True, remove_qconfig=False) + prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize_qat(model, run_fn, run_args, inplace=False): + r"""Do quantization aware training and output a quantized model + + Args: + model: input model + run_fn: a function for evaluating the prepared model, can be a + function that simply runs the prepared model or a training + loop + run_args: positional arguments for `run_fn` + + Return: + Quantized model. + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") + if not inplace: + model = copy.deepcopy(model) + model.train() + prepare_qat(model, inplace=True) + run_fn(model, *run_args) + convert(model, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert( + module, + mapping=None, + inplace=False, + remove_qconfig=True, + is_reference=False, + convert_custom_config_dict=None, + use_precomputed_fake_quant=False, +): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class. And remove qconfig at the + end if remove_qconfig is set to True. + + Args: + `module`: prepared and calibrated module + `mapping`: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + `inplace`: carry out model transformations in-place, the original module + is mutated + `convert_custom_config_dict`: custom configuration dictionary for convert function + `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant + + .. code-block:: python + + # Example of convert_custom_config_dict: + convert_custom_config_dict = { + # user will manually define the corresponding quantized + # module class which has a from_observed class method that converts + # observed custom module to quantized custom module + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: QuantizedCustomModule + } + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.convert") + if not inplace: + module = copy.deepcopy(module) + _convert( + module, + mapping, + inplace=True, + is_reference=is_reference, + convert_custom_config_dict=convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant, + ) + if remove_qconfig: + _remove_qconfig(module) + return module + + +def _convert( + module, + mapping=None, + inplace=False, + is_reference=False, + convert_custom_config_dict=None, + use_precomputed_fake_quant=False, +): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class + + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + is_reference: a flag to enable quantized reference module + use_precomputed_fake_quant: a flag to enable use of precomputed fake quant + + """ + if mapping is None: + mapping = ( + get_default_static_quant_reference_module_mappings() + if is_reference + else get_default_static_quant_module_mappings() + ) + if convert_custom_config_dict is None: + convert_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = convert_custom_config_dict.get( + "observed_to_quantized_custom_module_class", {} + ) + + if not inplace: + module = copy.deepcopy(module) + reassign = {} + for name, mod in module.named_children(): + # both fused modules and observed custom modules are + # swapped as one unit + if ( + not isinstance(mod, _FusedModule) + and type_before_parametrizations(mod) not in custom_module_class_mapping + ): + _convert( + mod, + mapping, + True, # inplace + is_reference, + convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant, + ) + reassign[name] = swap_module( + mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + +def swap_module( + mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False +): + r"""Swaps the module if it has a quantized counterpart and it has an + `observer` attached. + + Args: + mod: input module + mapping: a dictionary that maps from nn module to nnq module + + Return: + The corresponding quantized module of `mod` + """ + new_mod = mod + if hasattr(mod, "qconfig") and mod.qconfig is not None: + swapped = False + if type_before_parametrizations(mod) in custom_module_class_mapping: + new_mod = custom_module_class_mapping[ + type_before_parametrizations(mod) + ].from_observed(mod) + swapped = True + elif type_before_parametrizations(mod) in mapping: + qmod = mapping[type_before_parametrizations(mod)] + if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE: + if mod.qconfig is None: + raise AssertionError( + "module qconfig must not be None when swapping to reference module" + ) + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + weight_qparams = get_qparam_dict(weight_post_process) + new_mod = qmod.from_float(mod, weight_qparams) + else: + sig = inspect.signature(qmod.from_float) + if "use_precomputed_fake_quant" in sig.parameters: + new_mod = qmod.from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + else: + new_mod = qmod.from_float(mod) + swapped = True + + if swapped: + # Preserve module's pre forward hooks. They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + if hook_fn is not _observer_forward_hook: + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + devices = _get_unique_devices_(mod) + if not ( + len(devices) <= 1 + or (len(devices) == 2 and torch.device("meta") in devices) + ): + raise AssertionError( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + return new_mod + + +def _get_observer_dict(mod, target_dict, prefix=""): + r"""Traverse the modules and save all observers into dict. + This is mainly used for quantization accuracy debug + Args: + mod: the top module we want to save all observers + prefix: the prefix for the current module + target_dict: the dictionary used to save all the observers + """ + + def get_prefix(prefix): + return prefix if prefix == "" else prefix + "." + + if hasattr(mod, "activation_post_process"): + target_dict[get_prefix(prefix) + "activation_post_process"] = ( + mod.activation_post_process + ) + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_observer_dict(child, target_dict, module_prefix) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_fx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6ab86aaa048fbd128f9a89cc32d4e438d3fe12 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_fx.py @@ -0,0 +1,759 @@ +import copy +import typing_extensions +import warnings +from typing import Any + +import torch +from torch.fx import GraphModule +from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY + +from .backend_config import BackendConfig, get_tensorrt_backend_config # noqa: F401 +from .fx.convert import convert +from .fx.custom_config import ConvertCustomConfig, FuseCustomConfig, PrepareCustomConfig +from .fx.fuse import fuse # noqa: F401 +from .fx.graph_module import ObservedGraphModule # noqa: F401 +from .fx.prepare import prepare # noqa: F401 +from .fx.tracer import QuantizationTracer, Scope, ScopeContextManager # noqa: F401 +from .fx.utils import ( # noqa: F401 + get_custom_module_class_keys, + get_skipped_module_name_and_classes, +) +from .qconfig_mapping import QConfigMapping +from .utils import DEPRECATION_WARNING + + +def attach_preserved_attrs_to_model( + model: GraphModule | torch.nn.Module, + preserved_attrs: dict[str, Any], +) -> None: + """Store preserved attributes to the model.meta so that it can be preserved during deepcopy""" + model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment] + # set the preserved attributes in the model so that user can call + # model.attr as they do before calling fx graph mode quantization + for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): # type: ignore[index, union-attr] + setattr(model, attr_name, attr) + + +def _check_is_graph_module(model: torch.nn.Module) -> None: + if not isinstance(model, GraphModule): + raise ValueError( + "input model must be a GraphModule, " + + "Got type:" + + str(type(model)) + + " Please make " + + "sure to follow the tutorials." + ) + + +def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None: + """Attach meta field to all nodes of the graph if it does not exist, + meta field is a field stores some meta information about the node, such + as dtype and shape information for output of the node, this only exists + if the program is captured by make_fx (used in quantize_pt2e flow), if + the program is captured by torch.fx symbolic tracing, this field may not exist, + so we add it here to avoid checking this all over the places + """ + for node in model.graph.nodes: + if not hasattr(node, "meta"): + node.meta = {} + + +def _swap_ff_with_fxff(model: torch.nn.Module) -> None: + r"""Swap FloatFunctional with FXFloatFunctional""" + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + _swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + +def _fuse_fx( + model: GraphModule, + is_qat: bool, + fuse_custom_config: FuseCustomConfig | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + r"""Internal helper function to fuse modules in preparation for quantization + + Args: + model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace) + """ + _check_is_graph_module(model) + return fuse(model, is_qat, fuse_custom_config, backend_config) # type: ignore[operator] + + +def _prepare_fx( + model: torch.nn.Module, + qconfig_mapping: QConfigMapping | dict[str, Any], + is_qat: bool, + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None, + _equalization_config: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, + is_standalone_module: bool = False, +) -> GraphModule: + r"""Internal helper function for prepare_fx + Args: + `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`: + see docs for :func:`~torch.ao.quantization.prepare_fx` + `is_standalone_module`: a boolean flag indicates whether we are + quantizing a standalone module or not, a standalone module + is a submodule of the parent module that is not inlined in the + forward graph of the parent module, + the way we quantize standalone module is described in: + :func:`~torch.ao.quantization._prepare_standalone_module_fx` + """ + if prepare_custom_config is None: + prepare_custom_config = PrepareCustomConfig() + if _equalization_config is None: + _equalization_config = QConfigMapping() + + if isinstance(prepare_custom_config, dict): + warnings.warn( + "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) + prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) + + # swap FloatFunctional with FXFloatFunctional + _swap_ff_with_fxff(model) + + skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes( + prepare_custom_config, is_standalone_module + ) + preserved_attr_names = prepare_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(model, attr) + for attr in preserved_attr_names + if hasattr(model, attr) + } + # symbolically trace the model + tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) # type: ignore[arg-type] + graph_module = GraphModule(model, tracer.trace(model)) + _attach_meta_to_node_if_not_exist(graph_module) + + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + prepare_custom_config.preserved_attributes + ) + graph_module = _fuse_fx(graph_module, is_qat, fuse_custom_config, backend_config) + prepared = prepare( + graph_module, + qconfig_mapping, + is_qat, + tracer.node_name_to_scope, + example_inputs=example_inputs, + prepare_custom_config=prepare_custom_config, + _equalization_config=_equalization_config, + backend_config=backend_config, + is_standalone_module=is_standalone_module, + ) # type: ignore[operator] + + attach_preserved_attrs_to_model(prepared, preserved_attrs) + return prepared + + +def _prepare_standalone_module_fx( + model: torch.nn.Module, + qconfig_mapping: QConfigMapping | dict[str, Any], + is_qat: bool, + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + r"""[Internal use only] Prepare a standalone module, so that it can be used when quantizing the + parent module. + standalone_module means it a submodule that is not inlined in parent module, + and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + + * model(GraphModule): prepared standalone module. It has these attributes in + model.meta: + + * `standalone_module_input_quantized_idxs(List[Int])`: a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + * `standalone_module_output_quantized_idxs(List[Int])`: a list of + indices for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + + """ + return _prepare_fx( + model, + qconfig_mapping, + is_qat, + example_inputs, + prepare_custom_config, + backend_config=backend_config, + is_standalone_module=True, + ) + + +def fuse_fx( + model: torch.nn.Module, + fuse_custom_config: FuseCustomConfig | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + r"""Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. + Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py + + Args: + + * `model` (torch.nn.Module): a torch.nn.Module model + * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx. + See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details + Example:: + + from torch.ao.quantization import fuse_fx + + m = Model().eval() + m = fuse_fx(m) + + """ + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) + + torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") + preserved_attr_names = fuse_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(model, attr) + for attr in preserved_attr_names + if hasattr(model, attr) + } + + graph_module = torch.fx.symbolic_trace(model) + _attach_meta_to_node_if_not_exist(graph_module) + graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config) + + attach_preserved_attrs_to_model(graph_module, preserved_attrs) + return graph_module + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_fx( + model: torch.nn.Module, + qconfig_mapping: QConfigMapping | dict[str, Any], + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None, + _equalization_config: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + r""" Prepare a model for post training quantization + + Args: + * `model` (torch.nn.Module): torch.nn.Module model + + * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is + quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping` + for more details + + * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model, + Tuple of positional args (keyword args can be passed as positional args as well) + + * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool. + See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details + + * `_equalization_config`: config for specifying how to perform equalization on the model + + * `backend_config` (BackendConfig): config that specifies how operators are quantized + in a backend, this includes how the operators are observed, + supported fusion patterns, how quantize/dequantize ops are + inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details + + Return: + A GraphModule with observer (configured by qconfig_mapping), ready for calibration + + Example:: + + import torch + from torch.ao.quantization import get_default_qconfig_mapping + from torch.ao.quantization.quantize_fx import prepare_fx + + class Submodule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + def forward(self, x): + x = self.linear(x) + return x + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Submodule() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + x + return x + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # qconfig is the configuration for how we insert observers for a particular + # operator + # qconfig = get_default_qconfig("fbgemm") + # Example of customizing qconfig: + # qconfig = torch.ao.quantization.QConfig( + # activation=MinMaxObserver.with_args(dtype=torch.qint8), + # weight=MinMaxObserver.with_args(dtype=torch.qint8)) + # `activation` and `weight` are constructors of observer module + + # qconfig_mapping is a collection of quantization configurations, user can + # set the qconfig for each operator (torch op calls, functional calls, module calls) + # in the model through qconfig_mapping + # the following call will get the qconfig_mapping that works best for models + # that target "fbgemm" backend + qconfig_mapping = get_default_qconfig_mapping("fbgemm") + + # We can customize qconfig_mapping in different ways. + # e.g. set the global qconfig, which means we will use the same qconfig for + # all operators in the model, this can be overwritten by other settings + # qconfig_mapping = QConfigMapping().set_global(qconfig) + # e.g. quantize the linear submodule with a specific qconfig + # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig) + # e.g. quantize all nn.Linear modules with a specific qconfig + # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) + # for a more complete list, please see the docstring for :class:`torch.ao.quantization.QConfigMapping` + # argument + + # example_inputs is a tuple of inputs, that is used to infer the type of the + # outputs in the model + # currently it's not used, but please make sure model(*example_inputs) runs + example_inputs = (torch.randn(1, 3, 224, 224),) + + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + # `prepare_fx` inserts observers in the model based on qconfig_mapping and + # backend_config. If the configuration for an operator in qconfig_mapping + # is supported in the backend_config (meaning it's supported by the target + # hardware), we'll insert observer modules according to the qconfig_mapping + # otherwise the configuration in qconfig_mapping will be ignored + # + # Example: + # in qconfig_mapping, user sets linear module to be quantized with quint8 for + # activation and qint8 for weight: + # qconfig = torch.ao.quantization.QConfig( + # observer=MinMaxObserver.with_args(dtype=torch.quint8), + # weight=MinMaxObserver.with-args(dtype=torch.qint8)) + # Note: current qconfig api does not support setting output observer, but + # we may extend this to support these more fine grained control in the + # future + # + # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) + # in backend config, linear module also supports in this configuration: + # weighted_int8_dtype_config = DTypeConfig( + # input_dtype=torch.quint8, + # output_dtype=torch.quint8, + # weight_dtype=torch.qint8, + # bias_type=torch.float) + + # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \ + # .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + # .add_dtype_config(weighted_int8_dtype_config) \ + # ... + + # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config) + # `prepare_fx` will check that the setting requested by suer in qconfig_mapping + # is supported by the backend_config and insert observers and fake quant modules + # in the model + prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs) + # Run calibration + calibrate(prepared_model, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx") + return _prepare_fx( + model, + qconfig_mapping, + False, # is_qat + example_inputs, + prepare_custom_config, + _equalization_config, + backend_config, + ) + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat_fx( + model: torch.nn.Module, + qconfig_mapping: QConfigMapping | dict[str, Any], + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + r"""Prepare a model for quantization aware training + + Args: + * `model` (torch.nn.Module): torch.nn.Module model + * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx` + * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx` + * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx` + * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx` + + Return: + A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for + quantization aware training + + Example:: + + import torch + from torch.ao.quantization import get_default_qat_qconfig_mapping + from torch.ao.quantization.quantize_fx import prepare_qat_fx + + + class Submodule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear(x) + return x + + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Submodule() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + x + return x + + + # initialize a floating point model + float_model = M().train() + # (optional, but preferred) load the weights from pretrained model + # float_model.load_weights(...) + + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + + # qconfig is the configuration for how we insert observers for a particular + # operator + # qconfig = get_default_qconfig("fbgemm") + # Example of customizing qconfig: + # qconfig = torch.ao.quantization.QConfig( + # activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)), + # weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8))) + # `activation` and `weight` are constructors of observer module + + # qconfig_mapping is a collection of quantization configurations, user can + # set the qconfig for each operator (torch op calls, functional calls, module calls) + # in the model through qconfig_mapping + # the following call will get the qconfig_mapping that works best for models + # that target "fbgemm" backend + qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm") + + # We can customize qconfig_mapping in different ways, please take a look at + # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways + # to configure this + + # example_inputs is a tuple of inputs, that is used to infer the type of the + # outputs in the model + # currently it's not used, but please make sure model(*example_inputs) runs + example_inputs = (torch.randn(1, 3, 224, 224),) + + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and + # backend_config, if the configuration for an operator in qconfig_mapping + # is supported in the backend_config (meaning it's supported by the target + # hardware), we'll insert fake_quantize modules according to the qconfig_mapping + # otherwise the configuration in qconfig_mapping will be ignored + # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of + # how qconfig_mapping interacts with backend_config + prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs) + # Run training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") + return _prepare_fx( + model, + qconfig_mapping, + True, # is_qat + example_inputs, + prepare_custom_config, + backend_config=backend_config, + ) + + +def _convert_fx( + graph_module: GraphModule, + is_reference: bool, + convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None, + is_standalone_module: bool = False, + _remove_qconfig: bool = True, + qconfig_mapping: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, + is_decomposed: bool = False, + keep_original_weights: bool = False, +) -> GraphModule: + """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`""" + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + _check_is_graph_module(graph_module) + preserved_attr_names = convert_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(graph_module, attr) + for attr in preserved_attr_names + if hasattr(graph_module, attr) + } + + quantized = convert( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module, + _remove_qconfig_flag=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=is_decomposed, + keep_original_weights=keep_original_weights, + ) + + attach_preserved_attrs_to_model(quantized, preserved_attrs) + return quantized + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert_fx( + graph_module: GraphModule, + convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None, + _remove_qconfig: bool = True, + qconfig_mapping: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, + keep_original_weights: bool = False, +) -> GraphModule: + r"""Convert a calibrated or trained model to a quantized model + + Args: + * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + + The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`, + with the same values or `None`. Additional keys can be specified with values set to `None`. + + For each entry whose value is set to None, we skip quantizing that entry in the model:: + + qconfig_mapping = QConfigMapping + .set_global(qconfig_from_prepare) + .set_object_type(torch.nn.functional.add, None) # skip quantizing torch.nn.functional.add + .set_object_type(torch.nn.functional.linear, qconfig_from_prepare) + .set_module_name("foo.bar", None) # skip quantizing module "foo.bar" + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend, this includes quantization + mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.), + observer placement for each operators and fused operators. + See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details + + Return: + A quantized model (torch.nn.Module) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # convert_fx converts a calibrated/trained model to a quantized model for the + # target hardware, this includes converting the model first to a reference + # quantized model, and then lower the reference quantized model to a backend + # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and + # they share the same set of quantized operators, so we are using the same + # lowering procedure + # + # backend_config defines the corresponding reference quantized module for + # the weighted modules in the model, e.g. nn.Linear + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + quantized_model = convert_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") + return _convert_fx( + graph_module, + is_reference=False, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + keep_original_weights=keep_original_weights, + ) + + +def convert_to_reference_fx( + graph_module: GraphModule, + convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None, + _remove_qconfig: bool = True, + qconfig_mapping: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + r"""Convert a calibrated or trained model to a reference quantized model, + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = convert_to_reference_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx") + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + ) + + +def _convert_to_reference_decomposed_fx( + graph_module: GraphModule, + convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None, + qconfig_mapping: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + r"""Convert a calibrated or trained model to a reference quantized model, with + decomposed representation for quantized Tensor + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Note: this is not public API + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model) + + """ + torch._C._log_api_usage_once( + "quantization_api.quantize_fx._convert_to_reference_decomposed_fx" + ) + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=False, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=True, + ) + + +def _convert_standalone_module_fx( + graph_module: GraphModule, + is_reference: bool = False, + convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None, +) -> GraphModule: + r"""[Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx` + and convert it to a quantized model + + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details + """ + return _convert_fx( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module=True, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_jit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4caab1edcd010a66032cab51cae77ad8e4ed62 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_jit.py @@ -0,0 +1,423 @@ +# mypy: allow-untyped-defs + +import torch +from torch.ao.quantization.qconfig import QConfig +from torch.ao.quantization.quant_type import QuantType +from torch.jit._recursive import wrap_cpp_module + + +__all__ = [ + "script_qconfig", + "script_qconfig_dict", + "fuse_conv_bn_jit", + "prepare_jit", + "prepare_dynamic_jit", + "convert_jit", + "convert_dynamic_jit", + "quantize_jit", + "quantize_dynamic_jit", +] + + +def _check_is_script_module(model): + if not isinstance(model, torch.jit.ScriptModule): + raise ValueError("input must be a script module, got: " + str(type(model))) + + +def _check_forward_method(model): + if not model._c._has_method("forward"): + raise ValueError("input script module does not have forward method") + + +def script_qconfig(qconfig): + r"""Instantiate the activation and weight observer modules and script + them, these observer module instances will be deepcopied during + prepare_jit step. + """ + return QConfig( + activation=torch.jit.script(qconfig.activation())._c, + weight=torch.jit.script(qconfig.weight())._c, + ) + + +def script_qconfig_dict(qconfig_dict): + r"""Helper function used by `prepare_jit`. + Apply `script_qconfig` for all entries in `qconfig_dict` that is + not None. + """ + return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()} + + +def fuse_conv_bn_jit(model, inplace=False): + r"""Fuse conv - bn module + Works for eval model only. + + Args: + model: TorchScript model from scripting or tracing + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit") + model_c = model._c + model_c = torch._C._jit_pass_fold_convbn(model_c) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC): + _check_is_script_module(model) + _check_forward_method(model) + if not all(isinstance(x, str) for x in qconfig_dict): + raise ValueError("qconfig_dict should only contain names(str) as keys.") + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observers( + model._c, "forward", scripted_qconfig_dict, inplace, quant_type + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def _prepare_ondevice_jit( + model, + qconfig_dict, + method_name="forward", + inplace=False, + quant_type=QuantType.STATIC, +): + _check_is_script_module(model) + if not all(isinstance(x, str) for x in qconfig_dict): + raise ValueError("qconfig_dict should only contain names(str) as keys.") + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + method_graph = model._c._get_method(method_name).graph + torch._C._jit_pass_inline(method_graph) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq( + model._c, method_name, scripted_qconfig_dict, inplace, quant_type + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def prepare_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC) + + +def prepare_dynamic_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC) + + +def _prepare_ondevice_dynamic_jit( + model, qconfig_dict, method_name="forward", inplace=False +): + return _prepare_ondevice_jit( + model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC + ) + + +def _convert_jit( + model, inplace=False, debug=False, quant_type=QuantType.STATIC, preserved_attrs=None +): + _check_is_script_module(model) + model.eval() + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant( + model_c, "forward", inplace, debug, quant_type + ) + if not debug: + is_xpu = all(p.device.type == "xpu" for p in model.parameters()) + if not is_xpu: + # Moving model parameters to CPU since quantized operators + # are only supported on CPU and XPU right now + model.cpu() + if preserved_attrs is None: + preserved_attrs = [] + model_c = torch._C._jit_pass_quant_finalize( + model_c, quant_type, preserved_attrs + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + + +def _convert_ondevice_jit( + model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC +): + _check_is_script_module(model) + if quant_type != QuantType.DYNAMIC: + raise AssertionError( + "This API, while should work for static quant, is only tested for dynamic quant." + ) + if method_name.startswith("observe_"): + raise AssertionError("Pass in valid method to be quantized, e.g. forward") + observe_method_name = "observe_" + method_name + quantize_method_name = "quantize_" + method_name + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq( + model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC + ) + model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq( + model_c, QuantType.DYNAMIC, quantize_method_name + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def convert_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit") + return _convert_jit( + model, + inplace, + debug, + quant_type=QuantType.STATIC, + preserved_attrs=preserved_attrs, + ) + + +def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit") + return _convert_jit( + model, + inplace, + debug, + quant_type=QuantType.DYNAMIC, + preserved_attrs=preserved_attrs, + ) + + +def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False): + return _convert_ondevice_jit( + model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC + ) + + +def _quantize_ondevice_dynamic_jit_impl( + model, qconfig_dict, method_name, inplace=False +): + model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace) + model = _convert_ondevice_dynamic_jit(model, method_name, inplace) + return model + + +def _quantize_jit( + model, + qconfig_dict, + run_fn=None, + run_args=None, + inplace=False, + debug=False, + quant_type=QuantType.STATIC, +): + # Always do inplace convert because the Tensor is already + # copied in prepare_jit when inplace is False + if quant_type == QuantType.DYNAMIC: + model = prepare_dynamic_jit(model, qconfig_dict, inplace) + model = convert_dynamic_jit(model, True, debug) + else: + if not run_fn: + raise AssertionError( + "Must provide calibration function for post training static quantization" + ) + if not run_args: + raise AssertionError( + "Must provide calibration dataset for post training static quantization" + ) + model = prepare_jit(model, qconfig_dict, inplace) + run_fn(model, *run_args) + model = convert_jit(model, True, debug) + + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + + +def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False): + r"""Quantize the input float TorchScript model with + post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, empty key means the qconfig will be applied + to whole model unless it's overwritten by more specific configurations, the + qconfig for each module is either found in the dictionary or fallback to + the qconfig of parent module. + + Right now qconfig_dict is the only way to configure how the model is quantized, + and it is done in the granularity of module, that is, we only support one type + of qconfig for each torch.nn.Module, and the qconfig for sub module will + override the qconfig for parent module, empty string means global configuration. + `run_fn`: a calibration function for calibrating the prepared model + `run_args`: positional arguments for `run_fn` + `inplace`: carry out model transformations in-place, the original module is + mutated + `debug`: flag for producing a debug friendly model (preserve weight attribute) + + Return: + Quantized TorchSciprt model. + + Example: + ```python + import torch + from torch.ao.quantization import get_default_qconfig + from torch.ao.quantization import quantize_jit + + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + + + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + + quantized_model = quantize_jit( + ts_model, {"": qconfig}, calibrate, [data_loader_test] + ) + ``` + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit") + return _quantize_jit( + model, + qconfig_dict, + run_fn, + run_args, + inplace, + debug, + quant_type=QuantType.STATIC, + ) + + +def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False): + r"""Quantize the input float TorchScript model with + post training dynamic quantization. + Currently only qint8 quantization of torch.nn.Linear is supported. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, please see detailed + descriptions in :func:`~torch.ao.quantization.quantize_jit` + `inplace`: carry out model transformations in-place, the original module is + mutated + `debug`: flag for producing a debug friendly model (preserve weight attribute) + + Return: + Quantized TorchSciprt model. + + Example: + ```python + import torch + from torch.ao.quantization import per_channel_dynamic_qconfig + from torch.ao.quantization import quantize_dynamic_jit + + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + + + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + + quantized_model = quantize_dynamic_jit( + ts_model, {"": qconfig}, calibrate, [data_loader_test] + ) + ``` + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit") + return _quantize_jit( + model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC + ) + + +def _quantize_ondevice_dynamic_jit( + model, qconfig_dict, method_name="forward", inplace=False +): + r"""Prepares the input float TorchScript model with + *on-device* post training dynamic quantization. + Currently only qint8 quantization of torch.nn.Linear is supported. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, please see detailed + `method_name`: Name of the method within the model, to be prepared for quantization + descriptions in :func:`~torch.ao.quantization.quantize_jit` + `inplace`: carry out model transformations in-place, the original module is + mutated + + Return: + TorchScript model that is ready for on device quantization. + This means that the returned + model has: + - Method is inlined. + - Model has observer modules inserted in the model. + - Model has packed params inserted in the model. However they are empty as in they dont + contain valid quantized weights. + - observe_ is added that observe the values to be quantized. + - reset_observers_ to reset observers. + - quantize_ is added to the model. + - This method extract scale, zero points. + - Quantizes observed weights. + - Creates packed params from it and update the attribute of the model with the new values + for the packed params. + - Reset the original fp32 weights with empty tensor using SetAttr. + - quantized_ is added to the model. + - This method uses quantized weights and quantized linear ops instead of fp32 op. + - This method should be used for inference post PTQ. + - Note that all method's signatures should be the same as method_name. + + Later on device: + - Run reset_observers_ + - Run observe_ + - Run quantize_ + - Now model can be saved and loaded later. + - Run model with quantized_ + + Example: + ```python + import torch + from torch.ao.quantization import per_channel_dynamic_qconfig + from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit + + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + quant_ready_model = _quantize_ondevice_dynamic_jit( + ts_model, {"": qconfig}, "forward", True + ) + ``` + """ + return _quantize_ondevice_dynamic_jit_impl( + model, qconfig_dict, method_name, inplace=inplace + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_pt2e.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_pt2e.py new file mode 100644 index 0000000000000000000000000000000000000000..169e2905ddbdcc2ec86d92d1b858abe7e91af298 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantize_pt2e.py @@ -0,0 +1,262 @@ +import typing_extensions + +import torch +from torch._export.passes.constant_folding import constant_fold +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.quantizer import ( # noqa: F401 + DerivedQuantizationSpec, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassManager + +from .pt2e.prepare import prepare +from .pt2e.qat_utils import _fold_conv_bn_qat, _fuse_conv_bn_qat +from .pt2e.representation import reference_representation_rewrite +from .pt2e.utils import _disallow_eval_train, _fuse_conv_bn_, _get_node_name_to_scope +from .quantize_fx import _convert_to_reference_decomposed_fx +from .utils import DEPRECATION_WARNING + + +__all__ = [ + "prepare_pt2e", + "prepare_qat_pt2e", + "convert_pt2e", +] + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for post training quantization + + Args: + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. + * `quantizer`: A backend specific quantizer that conveys how user want the + model to be quantized. Tutorial for how to write a quantizer can be found here: + https://pytorch.org/tutorials/prototype/pt2e_quantizer.html + + Return: + A GraphModule with observer (based on quantizer annotation), ready for calibration + + Example:: + + import torch + from torch.ao.quantization.quantize_pt2e import prepare_pt2e + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result should mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_pt2e(m, quantizer) + + # run calibration + # calibrate(m, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + # TODO: check qconfig_mapping to make sure conv and bn are both configured + # to be quantized before fusion + # TODO: (maybe) rewrite this with subgraph_rewriter + _fuse_conv_bn_(model) + model = quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + model = prepare( + model, + node_name_to_scope, + is_qat=False, + obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, + ) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for quantization aware training + + Args: + * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + + Return: + A GraphModule with fake quant modules (based on quantizer annotation), ready for + quantization aware training + + Example:: + import torch + from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result should mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_qat_pt2e(m, quantizer) + + # run quantization aware training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + model = quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + # Perform fusion after annotate to avoid quantizing ops in the new + # subgraph that don't need to be quantized + # TODO: only fuse if conv and bn are both configured to be quantized + _fuse_conv_bn_qat(model) + model = prepare( + model, + node_name_to_scope, + is_qat=True, + obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, + ) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + + +_QUANT_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.pt2e_quant.quantize_affine, +] + + +def _quant_node_constraint(n: Node) -> bool: + """If there is any pure ops between get_attr and quantize op they will be const propagated + e.g. get_attr(weight) -> transpose -> quantize -> dequantize* + (Note: dequantize op is not going to be constant propagated) + + This filter is added because we don't want to constant fold the things that are not + related to quantization + """ + return n.op == "call_function" and n.target in _QUANT_OPS + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert_pt2e( + model: GraphModule, + use_reference_representation: bool = False, + fold_quantize: bool = True, +) -> GraphModule: + """Convert a calibrated/trained model to a quantized model + + Args: + * `model` (torch.fx.GraphModule): calibrated/trained model + * `use_reference_representation` (bool): boolean flag to indicate whether to produce reference representation or not + * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not + + Returns: + quantized model, either in q/dq representation or reference representation + + Example:: + + # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training + # `convert_pt2e` produces a quantized model that represents quantized computation with + # quantize dequantize ops and fp32 ops by default. + # Please refer to + # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model + # for detailed explanation of output quantized model + quantized_model = convert_pt2e(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e") + if not isinstance(use_reference_representation, bool): + raise ValueError( + "Unexpected argument type for `use_reference_representation`, " + f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e" + ) + original_graph_meta = model.meta + model = _convert_to_reference_decomposed_fx(model) + model = _fold_conv_bn_qat(model) + + pm = PassManager([DuplicateDQPass()]) + model = pm(model).graph_module + + pm = PassManager([PortNodeMetaForQDQ()]) + model = pm(model).graph_module + + if fold_quantize: + constant_fold(model, _quant_node_constraint) + + if use_reference_representation: + model = reference_representation_rewrite(model) + + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..265c230d974e561cf59292eb738028fd222620ed Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..789db870327090ff8308e897de0bb145ec082561 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..666223830f662cbec2d9f48b3ba329d4f542b9d2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38dee4e9de0b7c127eee92d0245caaaf5fb02024 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c32f8dd74337c0d1c4cd7b02771888c53fafd4c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90ef29f1475dd1fa0eb822e1113f7eaf58c4b424 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c69d70018ce30c135c7d57039e3b9274acedbbaf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e31b3d8400e42ebfee4c76e34481f7d55c88eee1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d73877fdbf7f96f9654241606e08a2372d6b4e7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/stubs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd05374eff844be2cec2d913b88a338aded4e6a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/stubs.py @@ -0,0 +1,74 @@ +from typing import Any + +import torch +from torch import nn +from torch.ao.quantization import QConfig + + +__all__ = ["QuantStub", "DeQuantStub", "QuantWrapper"] + + +class QuantStub(nn.Module): + r"""Quantize stub module, before calibration, this is same as an observer, + it will be swapped as `nnq.Quantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig: QConfig | None = None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class DeQuantStub(nn.Module): + r"""Dequantize stub module, before calibration, this is same as identity, + this will be swapped as `nnq.DeQuantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig: Any | None = None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class QuantWrapper(nn.Module): + r"""A wrapper class that wraps the input module, adds QuantStub and + DeQuantStub and surround the call to module with call to quant and dequant + modules. + + This is used by the `quantization` utility functions to add the quant and + dequant modules, before `convert` function `QuantStub` will just be observer, + it observes the input tensor, after `convert`, `QuantStub` + will be swapped to `nnq.Quantize` which does actual quantization. Similarly + for `DeQuantStub`. + """ + + quant: QuantStub + dequant: DeQuantStub + module: nn.Module + + def __init__(self, module: nn.Module): + super().__init__() + qconfig = getattr(module, "qconfig", None) + self.add_module("quant", QuantStub(qconfig)) + self.add_module("dequant", DeQuantStub(qconfig)) + self.add_module("module", module) + self.train(module.training) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + X = self.quant(X) + X = self.module(X) + return self.dequant(X) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84a027e17e6b07cfbddc8b7b436ba0299b32ef91 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/quantization/utils.py @@ -0,0 +1,875 @@ +# mypy: allow-untyped-defs +""" +Utils shared by different modes of quantization (eager/graph) +""" + +import functools +import sys +import warnings +from collections import OrderedDict +from collections.abc import Callable +from inspect import getfullargspec, signature +from typing import Any, Union + +import torch +from torch.ao.quantization.quant_type import QuantType +from torch.fx import Node +from torch.nn.utils.parametrize import is_parametrized + + +if sys.version_info < (3, 12): + NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] + NodePattern.__module__ = "torch.ao.quantization.utils" +else: + from typing import TypeAliasType + + NodePattern = TypeAliasType( + "NodePattern", tuple[Node, Node] | tuple[Node, tuple[Node, Node]] | Any + ) + + +# This is the Quantizer class instance from torch/quantization/fx/quantize.py. +# Define separately to prevent circular imports. +# TODO(future PR): improve this. +# make this public once fixed (can't be public as is because setting the module directly +# doesn't work) +QuantizerCls = Any + +# Type for fusion patterns, it can be more complicated than the following actually, +# see pattern.md for docs +# TODO: not sure if typing supports recursive data types + +if sys.version_info < (3, 12): + Pattern = Union[ + Callable, + tuple[Callable, Callable], + tuple[Callable, tuple[Callable, Callable]], + Any, + ] + Pattern.__module__ = "torch.ao.quantization.utils" +else: + from typing import TypeAliasType + + Pattern = TypeAliasType( + "Pattern", + Callable + | tuple[Callable, Callable] + | tuple[Callable, tuple[Callable, Callable]] + | Any, + ) + + +# TODO: maybe rename this to MatchInputNode +class MatchAllNode: + """A node pattern that matches all nodes, used in defining + fusion patterns in FX Graph Mode Quantization + """ + + +module_type_list = { + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.Identity, + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, +} +func_list = { + torch.nn.functional.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.silu, + torch.nn.functional.mish, + torch.nn.functional.dropout, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.transpose, + torch.repeat_interleave, + torch.sigmoid, + torch.squeeze, + torch.stack, + torch.sum, + torch.tanh, + torch.unsqueeze, + torch.cat, +} +method_list = { + torch.mean, + "relu", + "relu_", + "contiguous", + "detach", + "detach_", + "hardsigmoid", + "hardsigmoid_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "shape", + "sigmoid", + "sigmoid_", + "size", + "squeeze", + "squeeze_", + "tanh", + "tanh_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", +} + + +# TODO: not used now, remove +def check_node(node, modules): + # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = ( + node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + ) + return is_call_function, is_call_method, is_call_module + + +def get_combined_dict(default_dict, additional_dict): + """ + Combines two dictionaries. + + This function takes two dictionaries as input and returns a new dictionary + that contains all the key-value pairs from both input dictionaries. + If there are any duplicate keys in the `additional_dict`, the values + from the `additional_dict` will overwrite those in the `default_dict`. + Args: + default_dict (dict): The main dictionary that will be used as the base + additional_dict (dict): The dictionary used to update `default_dict` + + Returns: + dict: The resulting dictionary + Example: + >>> x = dict(a=1, b=1) + >>> y = dict(b=2, c=3) + >>> get_combined_dict(x, y) + {'a': 1, 'b': 2, 'c': 3} + """ + d = default_dict.copy() + d.update(additional_dict) + return d + + +def is_per_tensor(qscheme): + return qscheme == torch.per_tensor_affine or qscheme == torch.per_tensor_symmetric + + +def is_per_channel(qscheme): + return qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + ] + + +def getattr_from_fqn(obj: Any, fqn: str) -> Any: + """ + Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. + """ + return functools.reduce(getattr, fqn.split("."), obj) + + +def to_underlying_dtype(qdtype): + DTYPE_MAPPING = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.uint8: torch.uint8, + torch.int8: torch.int8, + torch.uint16: torch.uint16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.float8_e5m2: torch.float8_e5m2, + torch.float8_e4m3fn: torch.float8_e4m3fn, + } + if qdtype not in DTYPE_MAPPING: + raise AssertionError("Unsupported dtype: " + str(qdtype)) + return DTYPE_MAPPING[qdtype] + + +def get_qparam_dict(observer_or_fake_quant): + from torch.ao.quantization.observer import PlaceholderObserver + + qscheme = getattr(observer_or_fake_quant, "qscheme", None) + dtype = observer_or_fake_quant.dtype + qparams = {"qscheme": qscheme, "dtype": dtype} + + if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver): + return {"qscheme": None, "dtype": dtype} + + if is_per_tensor(qscheme): + qscheme = torch.per_tensor_affine + elif is_per_channel(qscheme): + # change symmetric to affine since we do not have symmetric + # quantized Tensor + if qscheme == torch.per_channel_symmetric: + qscheme = torch.per_channel_affine + qparams["axis"] = observer_or_fake_quant.ch_axis + else: + raise RuntimeError(f"Unrecognized qscheme: {qscheme}") + # update qscheme, since we don't have symmetric quant qscheme + # in quantized Tensor + qparams["qscheme"] = qscheme + + scale, zero_point = observer_or_fake_quant.calculate_qparams() + qparams["scale"] = scale + qparams["zero_point"] = zero_point + + if hasattr(observer_or_fake_quant, "quant_min"): + qparams["quant_min"] = observer_or_fake_quant.quant_min + if hasattr(observer_or_fake_quant, "quant_max"): + qparams["quant_max"] = observer_or_fake_quant.quant_max + + return qparams + + +def get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig +): + """Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + class_mapping = custom_module_class_mapping.get(quant_type, {}) + if type(custom_module) not in class_mapping: + raise AssertionError( + "did not find corresponding observed " + f"module class for {type(custom_module)} in mapping: {class_mapping}" + ) + return class_mapping[type(custom_module)] + + +def activation_dtype(qconfig): + if qconfig is None: + raise AssertionError("qconfig must be provided to determine activation dtype") + activation = qconfig.activation() + return activation.dtype + + +def weight_dtype(qconfig): + if qconfig is None: + raise AssertionError("qconfig must be provided to determine weight dtype") + weight = qconfig.weight() + return weight.dtype + + +def activation_is_statically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not activation_is_dynamically_quantized(qconfig)) + + +def activation_is_dynamically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + dynamically quantized or not, this includes dynamically quantizing to + quint8, qint8 and float16 + """ + _activation_dtype, _, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return activation_is_dynamic + + +def activation_is_int8_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int8 or not, this includes quantizing to quint8, qint8 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.uint8, + torch.int8, + ] + + +def activation_is_int32_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int32 or not + """ + return activation_dtype(qconfig) in [torch.qint32, torch.int32] + + +def weight_is_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be + quantized or not + """ + return weight_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.float16, + torch.quint4x2, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + + +def weight_is_statically_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be statically + quantized or not + """ + return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8] + + +def op_is_int8_dynamically_quantized(qconfig) -> bool: + """Given a qconfig, returns True if this op is using int8 dynamic + quantization + """ + activation_dtype, weight_dtype, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return ( + activation_dtype in [torch.quint8, torch.uint8] + and + # for now, the lines below assume fbgemm or qnnpack + weight_dtype in [torch.qint8, torch.int8] + and activation_is_dynamic + ) + + +def get_qconfig_dtypes(qconfig): + r"""returns the qconfig tuple for qconfig: + (activation_dtype, weight_dtype, activation_is_dynamic) + """ + if qconfig is None: + raise AssertionError("qconfig must be provided to extract dtypes") + activation = qconfig.activation() + weight = qconfig.weight() + act_is_dynamic = getattr(activation, "is_dynamic", False) + return (activation.dtype, weight.dtype, act_is_dynamic) + + +def get_quant_type(qconfig): + if qconfig is None: + raise AssertionError("qconfig must be provided to determine quant type") + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [ + torch.quint8, + torch.qint8, + torch.quint4x2, + torch.qint32, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + if weight.dtype in static_dtypes: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype in static_dtypes: + return QuantType.STATIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype == torch.float16: + return QuantType.STATIC + + raise Exception( # noqa: TRY002 + f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype})," + f"weight({weight.dtype})" + ) + + +def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: + """Checks if the given minimum and maximum values are valid, meaning that + they exist and the min value is less than the max value. + """ + if min_val.numel() == 0 or max_val.numel() == 0: + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values.", + stacklevel=2, + ) + return False + + if min_val.dim() == 0 or max_val.dim() == 0: + if min_val == float("inf") and max_val == float("-inf"): + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values.", + stacklevel=2, + ) + + return False + + if min_val > max_val: + raise AssertionError(f"min {min_val} should be less than max {max_val}") + else: + if torch.any(min_val > max_val): + raise AssertionError(f"min {min_val} should be less than max {max_val}") + + return True + + +def calculate_qmin_qmax( + quant_min: int, + quant_max: int, + has_customized_qrange: bool, + dtype: torch.dtype, + reduce_range: bool, +) -> tuple[int, int]: + r"""Calculates actual qmin and qmax based on the quantization range, + observer datatype and if range is reduced. + """ + # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted. + if has_customized_qrange: + # This initialization here is to be resolve TorchScript compilation issues and allow + # using of refinement to decouple initial_qmin and initial_qmax from quantization range. + # The actual values of initial_qmin and initial_qmax will be reset below. + if dtype in [torch.qint32, torch.int32]: + initial_quant_min, initial_quant_max = 0, 2**32 - 1 + else: + initial_quant_min, initial_quant_max = 0, 255 + # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the + # attribute from Optional valid integers for use, based on TorchScript's requirements. + custom_quant_min, custom_quant_max = quant_min, quant_max + if custom_quant_min is not None and custom_quant_max is not None: + initial_quant_min, initial_quant_max = ( + custom_quant_min, + custom_quant_max, + ) + + qrange_len = initial_quant_max - initial_quant_min + 1 + if dtype in [torch.qint8, torch.int8]: + if not (0 < qrange_len <= 256): + raise AssertionError( + "quantization range should be positive and not exceed the maximum bit range (=256)." + ) + elif dtype in [torch.qint32, torch.int32]: + if not (0 < qrange_len <= 2**32): + raise AssertionError( + "quantization range should be positive and not exceed the maximum bit range (=4294967296)." + ) + if reduce_range: + quant_min, quant_max = quant_min // 2, quant_max // 2 + else: + # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used. + if dtype in [torch.qint8, torch.int8]: + if reduce_range: + quant_min, quant_max = -64, 63 + else: + quant_min, quant_max = -128, 127 + elif dtype in [torch.quint8, torch.uint8]: + if reduce_range: + quant_min, quant_max = 0, 127 + else: + quant_min, quant_max = 0, 255 + elif dtype in [torch.qint32, torch.int32]: + quant_min, quant_max = -1 * (2**31), (2**31) - 1 + elif dtype == torch.uint16: + quant_min, quant_max = 0, 2**16 - 1 + elif dtype == torch.int16: + quant_min, quant_max = -(2**15), 2**15 - 1 + else: + quant_min, quant_max = 0, 15 + return quant_min, quant_max + + +def _parent_name(target): + """ + Turn 'foo.bar' into ['foo', 'bar'] + """ + r = target.rsplit(".", 1) + if len(r) == 1: + return "", r[0] + else: + return r[0], r[1] + + +def has_no_children_ignoring_parametrizations(module): + """ + Checks if module._modules is empty or + if module is a parametrization, checks that module._modules only has + the 'parametrizations' module + """ + if len(module._modules) == 0: + return True + elif is_parametrized(module): + return len(module._modules) == 1 and "parametrizations" in module._modules + else: + return False + + +def _get_path_of_module( + root: torch.nn.Module, submodule: torch.nn.Module +) -> str | None: + """Get the path (fully qualified name) of a submodule + + Example:: + + >> class M(torch.nn.Module): + def __init__(self) -> None: + self.linear = torch.nn.Linear(5, 5) + def forward(self, x): + return self.linear(x) + + >> m = M() + >> l = m.linear + >> _get_path_of_module(m, l) + "linear" + """ + for n, p in root.named_modules(): + if submodule is p: + return n + return None + + +def _get_signature_locals(f: Callable, loc: dict[str, Any]) -> dict[str, Any]: + """Get local keyword arguments + + Example:: + + >> def f(self, a, b=9): + pass + >> loc = {"a": 6, "c": 7} + >> _get_signature_locals(f, loc) + {"a": 6} + """ + return {k: v for k, v in loc.items() if k in signature(f).parameters} + + +def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]": + """Get all default keyword arguments from function signature + + Example:: + + >> def f(self, a, b=9): + pass + >> _get_default_kwargs(f) + {"b": 9} + """ + kwargs = {} + for name, param in signature(f).parameters.items(): + if param.default is not param.empty: + kwargs[name] = param.default + elif param.kind is param.VAR_POSITIONAL: + kwargs[name] = () + elif param.kind is param.VAR_KEYWORD: + kwargs[name] = {} + return OrderedDict(kwargs) + + +def _normalize_kwargs(func: Callable, loc: dict[str, Any]) -> "OrderedDict[str, Any]": + """Given a function and local function arguments, normalize the keyword + arguments by filling in default arguments from function signature + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> loc = {"key2": 6} + >> _normalize_kwargs(f, loc) + {"key1": 3, "key2": 6} + """ + default_kwargs = _get_default_kwargs(func) + local_kwargs = _get_signature_locals(func, loc) + normalized_kwargs = default_kwargs.copy() + for attr, val in local_kwargs.items(): + if attr in normalized_kwargs: + # override the default keyword arguments + normalized_kwargs[attr] = val + return normalized_kwargs + + +def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. + + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + The related literatures for scale and zero point via backpropagation are as follows: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. + if not (quant_min <= 0 <= quant_max): + raise AssertionError("Used-specified quantization range must include 0.") + if quant_min >= quant_max: + raise AssertionError( + "qmin must be strictly less than qmax for user-specified quantization range." + ) + + +# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme +# as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer +# to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikely to change +# (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168) +def determine_qparams( + min_val: torch.Tensor, + max_val: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + eps: torch.Tensor, + has_customized_qrange: bool, + qscheme: torch.qscheme = torch.per_tensor_affine, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor( + [0], device=min_val.device.type + ) + + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + eps = eps.to(device) + + if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, eps) + if dtype in [torch.uint8, torch.quint8]: + if has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. + zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. + if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale.to(torch.double), zero_point.to(torch.int64) + + +def _get_num_pos_args(f: Callable) -> int: + """Get number of positional args for a function + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> _get_num_pos_args(f) + 3 + """ + return len(getfullargspec(f).args) + + +def get_fqn_to_example_inputs( + model: torch.nn.Module, example_inputs: tuple[Any, ...] +) -> dict[str, tuple[Any, ...]]: + """Given a model and its example inputs, return a dictionary from + fully qualified name of submodules to example_inputs for that submodule, + e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,), + "sub.linear1": (tensor4,), ...} + + Used to make quantizing submodules easier now that FX Graph Mode Quantization requires + example inputs. + + Also works for keyword arguments with default values, we would flatten keyword + arguments as positional arguments and fill in the missing keyword args with default + values, e.g. if we have a forward function: + def forward(self, x, key1=3, key2=3): + ... + + and we call it with self.submodule(x, key2=6) + we'll get example_inputs: (x, 3, 6) + + user can also override `key1` with positional arguments as well: + for self.submodule(x, 5, key2=6) + we'll get: (x, 5, 6) + + variable positional arguments and variable positional keyword arguments in forward + function are not supported currently, so please make sure no submodules is using + them. + """ + root = model + fqn_to_example_inputs = {} + + def _patched_module_call(self, *args, **kwargs): + submodule_example_inputs = list(args).copy() + normalized_kwargs = _normalize_kwargs(self.forward, kwargs) + # minus 1 to skipping counting `self` + num_args = _get_num_pos_args(self.forward) - 1 + num_to_pop = num_args - len(submodule_example_inputs) + while num_to_pop and normalized_kwargs: + normalized_kwargs.popitem(last=False) + num_to_pop -= 1 + submodule_example_inputs.extend(normalized_kwargs.values()) + submodule_example_inputs_tuple = tuple(submodule_example_inputs) + fqn = _get_path_of_module(root, self) + if fqn is not None: + fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple + return orig_module_call(self, *args, **kwargs) + + orig_module_call = torch.nn.Module.__call__ + torch.nn.Module.__call__ = _patched_module_call # type: ignore[method-assign] + try: + model(*example_inputs) + finally: + # restore the module call even if there is an exception + torch.nn.Module.__call__ = orig_module_call # type: ignore[method-assign] + return fqn_to_example_inputs + + +def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + """ + As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564 + """ + if {torch.device("cpu"), torch.device("meta")} == devices: + warnings.warn( + "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.", + stacklevel=2, + ) + devices = {torch.device("cpu")} + "" + if len(devices) > 1: + raise AssertionError( + "prepare only works with cpu or single-device CUDA modules, " + f"but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + return device + + +DEPRECATION_WARNING = ( + "torch.ao.quantization is deprecated and will be removed in 2.10. \n" + "For migrations of users: \n" + "1. Eager mode quantization (torch.ao.quantization.quantize, " + "torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode " + "quantize_ API instead \n" + "2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx," + "torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization " + "API instead (prepare_pt2e, convert_pt2e) \n" + "3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) \n" + "see https://github.com/pytorch/ao/issues/2259 for more details" +) + + +__all__ = [ + "NodePattern", + "Pattern", + "MatchAllNode", + "check_node", + "get_combined_dict", + "is_per_tensor", + "is_per_channel", + "getattr_from_fqn", + "get_qparam_dict", + "get_swapped_custom_module_class", + "activation_dtype", + "weight_dtype", + "activation_is_statically_quantized", + "activation_is_dynamically_quantized", + "activation_is_int8_quantized", + "activation_is_int32_quantized", + "weight_is_quantized", + "weight_is_statically_quantized", + "op_is_int8_dynamically_quantized", + "get_qconfig_dtypes", + "get_quant_type", + "check_min_max_valid", + "calculate_qmin_qmax", + "has_no_children_ignoring_parametrizations", + "get_fqn_to_example_inputs", + "to_underlying_dtype", + "determine_qparams", + "validate_qmin_qmax", + "DEPRECATION_WARNING", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d5bb7ea5709457e2f72807e46c2c1523016bbcd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb5f9f5e4d1354ce0046c7a20c24f9f30285fa56 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/forward_ad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/forward_ad.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db9ca7bb37059ead425060577b1ffc065a68bb21 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/forward_ad.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/function.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f17cb71c65dadceaa202e3e854c04ead00e05f4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/function.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/functional.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d867f83540681fe5e3077d1bf49952418d3b8fc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/functional.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/grad_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/grad_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6aaf624294e42a457a09706fe116ef619126f8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/grad_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/gradcheck.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/gradcheck.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cec91b60f5b53d9d2c96213e4713260f04245987 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/gradcheck.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/graph.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bce9f6dab149190cc1467e3787c3e4652c48037 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/graph.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..488c173578af20aa1139a136441a4902d33d7149 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d0c3a2b290274ff48d5dc5fd2349672607b20c2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_util.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc0c17b1c1ec38a3a081fc03c562c4d89bce2fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_util.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/variable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/variable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09b11194ffa5247b80fd0c84ad20d134d82e81a4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/__pycache__/variable.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4170fad3eeac788dcb36b6ae1ddbee1b44dc25a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__init__.py @@ -0,0 +1 @@ +from .tensor import * # noqa: F403 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab5ad6cb5aa7d189c4c9b5f5c8eff0bafd43a0a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ac7d801e73d6af231015f7f36c79835bb18a8a6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdf4b2bd9fe5076b30cb0e5ecbf64a4c4465ed98 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/tensor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..716ae1db726ad5b397426e0669cfd241ee7ee556 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/tensor.py @@ -0,0 +1,72 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing_extensions import deprecated + +import torch +import torch._utils +from torch.autograd.function import Function + + +class Type(Function): + @staticmethod + @deprecated( + "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, " + "please use `torch.tensor.to(dtype=dtype)` instead.", + category=FutureWarning, + ) + # pyrefly: ignore [bad-override] + def forward(ctx, i, dest_type): + ctx.input_type = type(i) + ctx.input_device = -1 if not i.is_cuda else i.get_device() + return i.type(dest_type) + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + if ctx.input_device == -1: + return grad_output.type(ctx.input_type), None + else: + with torch.accelerator.device_index(ctx.input_device): + return grad_output.type(ctx.input_type), None + + +# TODO: deprecate this +class Resize(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, tensor, sizes): + ctx.sizes = sizes + ctx.numel = reduce(operator.mul, sizes, 1) + if tensor.numel() != ctx.numel: + raise RuntimeError( + ( + "requested resize to {} ({} elements in total), " + "but the given tensor has a size of {} ({} elements). " + "autograd's resize can only change the shape of a given " + "tensor, while preserving the number of elements. " + ).format( + "x".join(map(str, sizes)), + ctx.numel, + "x".join(map(str, tensor.size())), + tensor.numel(), + ) + ) + ctx.input_sizes = tensor.size() + if tensor.is_quantized: + tensor.copy_(tensor) + return tensor.contiguous().view(*sizes) + if tensor.is_contiguous(): + result = tensor.new(tensor).contiguous().view(*sizes) + return result + else: + return tensor.contiguous().view(*sizes) + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + if grad_output.numel() != ctx.numel: + raise AssertionError( + f"Expected grad_output to have {ctx.numel} elements, but got {grad_output.numel()}" + ) + return grad_output.contiguous().view(ctx.input_sizes), None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e74e21d3cef22c0fd459eff5934d4e531d5456d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/autograd/_functions/utils.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs + + +def maybe_view(tensor, size, check_same_size=True): + if check_same_size and tensor.size() == size: + return tensor + return tensor.contiguous().view(size) + + +def maybe_unexpand(tensor, old_size, check_same_size=True): + if check_same_size and tensor.size() == old_size: + return tensor + num_unsqueezed = tensor.dim() - len(old_size) + expanded_dims = [ + dim + for dim, (expanded, original) in enumerate( + zip(tensor.size()[num_unsqueezed:], old_size) + ) + if expanded != original + ] + + for _ in range(num_unsqueezed): + tensor = tensor.sum(0, keepdim=False) + for dim in expanded_dims: + tensor = tensor.sum(dim, keepdim=True) + return tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/contrib/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/contrib/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bad3ca4996a30cf030b90ebbf3b95d825b0629d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/contrib/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fc0d653de3753208419fd5b3d4e582c14052406 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c9f34d23d0be6366af70073d932eec3238b195 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_device_limits.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_device_limits.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93fac43982f0c2280a9ec2201ace3478c584f3d8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_device_limits.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_gpu_trace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_gpu_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05a2c73c2c6094ab8c70e4fce69bced403c3ad6c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_gpu_trace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c27b919dd89996fadc9ee8292d39d874bd558098 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8abbe854f1818cbe330db56568f3c74aa78cb58d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cd4a95c111e8d749ee64a2091fa710776151d52 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ad23b60cce163e09437f54df62213b2be780405 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/comm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/comm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02d939d5459a7ba669cc25420974dbea826f5372 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/comm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/gds.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/gds.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c8e6fc20d9dee34936d1358ff99c51367927cb0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/gds.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/graphs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/graphs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06f5354906622fb22e9d84f4c33e93d83ee9bb88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/graphs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/green_contexts.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/green_contexts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d164eab54d9abb1bcc9a5a10f956b7f04b8cd80 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/green_contexts.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/jiterator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/jiterator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e229a8a458f0b42f9029cfd8c4bb5b94ee1496a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/jiterator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/memory.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37c8d9925dfcb9dcb7b17881ce4c84cb3945f254 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/memory.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/nccl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/nccl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78ed8897b1ff04b64b184c51353d4f0bb999b5b9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/nccl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/nvtx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/nvtx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da6d2988144a54cb2b0f6451f53e791a7dc482bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/nvtx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/profiler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad13f9295b6a755cd3f8c9d715951431665f3634 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/profiler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/random.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45fd38c5c931f9760b0ca409fa9fa2f19b654a14 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/random.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/sparse.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/sparse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..788e4b7fbf3af505314d1778b906c11532a7bbc1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/sparse.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/streams.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/streams.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92adb70c7a419e82680fee77a486174e039d008d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/streams.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/tunable.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/tunable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..234af882b14581cd892735af03d1c5921970f301 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/__pycache__/tunable.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88ef0b5acac5e5bdeb034169052bcf5aa7456e33 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__init__.py @@ -0,0 +1,13 @@ +# pyrefly: ignore [deprecated] +from .autocast_mode import autocast, custom_bwd, custom_fwd +from .common import amp_definitely_not_available +from .grad_scaler import GradScaler + + +__all__ = [ + "amp_definitely_not_available", + "autocast", + "custom_bwd", + "custom_fwd", + "GradScaler", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0319d5e3c0ab6ed70057321d87ffc85691be5e27 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ba588d350d01136c1d5b822552a8d6d24a3ed5d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/common.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf84a5a36fd7e63ecc7673377261a0f5d08d6a2c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/common.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11eb84b56389632cae3303b29d604e8b7f125e4e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/autocast_mode.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/autocast_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b63c708d3f2ddfc162a4431e114a2bcf47e9eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/autocast_mode.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs +import functools +import sys +from typing import Any +from typing_extensions import deprecated + +import torch + + +__all__ = ["autocast", "custom_fwd", "custom_bwd"] + + +@deprecated( + "`torch.cuda.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cuda', args...)` instead.", + category=FutureWarning, +) +class autocast(torch.amp.autocast_mode.autocast): + r"""See :class:`torch.autocast`. + + ``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead. + """ + + # TODO: remove this conditional once we stop supporting Python < 3.13 + # Prior to Python 3.13, inspect.signature could not retrieve the correct + # signature information for classes decorated with @deprecated (unless + # the __new__ static method was explicitly defined); + # + # However, this issue has been fixed in Python 3.13 and later versions. + if sys.version_info < (3, 13): + + def __new__( + cls, + enabled: bool = True, + dtype: torch.dtype = torch.float16, + cache_enabled: bool = True, + ): + return super().__new__(cls) + + def __init_subclass__(cls): + pass + + def __init__( + self, + enabled: bool = True, + dtype: torch.dtype = torch.float16, + cache_enabled: bool = True, + ): + if torch._jit_internal.is_scripting(): + self._enabled = enabled + self.device = "cuda" + self.fast_dtype = dtype + return + super().__init__( + "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) + + def __enter__(self): + if torch._jit_internal.is_scripting(): + return self + return super().__enter__() + + # TODO: discuss a unified TorchScript-friendly API for autocast + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if torch._jit_internal.is_scripting(): + return + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + if torch._jit_internal.is_scripting(): + return func + return super().__call__(func) + + +# Preserved only for BC reasons +@deprecated( + "`torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. " + "Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.", + category=FutureWarning, +) +def _cast(value, dtype): + return torch.amp.autocast_mode._cast(value, "cuda", dtype) + + +@deprecated( + "`torch.cuda.amp.custom_fwd(args...)` is deprecated. " + "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_fwd(fwd=None, *, cast_inputs=None): + """ + ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use + ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead. + """ + return functools.partial(torch.amp.custom_fwd, device_type="cuda")( + fwd=fwd, cast_inputs=cast_inputs + ) + + +@deprecated( + "`torch.cuda.amp.custom_bwd(args...)` is deprecated. " + "Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.", + category=FutureWarning, +) +def custom_bwd(bwd): + """ + ``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use + ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead. + """ + return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/common.py new file mode 100644 index 0000000000000000000000000000000000000000..915a9b4f4a9ca6c147abefd7c8ab1891ee5a8179 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/common.py @@ -0,0 +1,11 @@ +# mypy: allow-untyped-defs +from importlib.util import find_spec + +import torch + + +__all__ = ["amp_definitely_not_available"] + + +def amp_definitely_not_available(): + return not (torch.cuda.is_available() or find_spec("torch_xla")) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/grad_scaler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..62e2020073c8ed99f7295edd1aaea4c54d815f63 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/cuda/amp/grad_scaler.py @@ -0,0 +1,38 @@ +from typing_extensions import deprecated + +import torch + +# We need to keep this unused import for BC reasons +from torch.amp.grad_scaler import OptState # noqa: F401 + + +__all__ = ["GradScaler"] + + +class GradScaler(torch.amp.GradScaler): + r""" + See :class:`torch.amp.GradScaler`. + ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead. + """ + + @deprecated( + "`torch.cuda.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cuda', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + super().__init__( + "cuda", + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b710eb61f38a589abf3e91224fdfad0153e51f65 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/futures/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/futures/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c822c956dce22374ff8d0e08fcd1665b4c0dd71 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/futures/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..675f6dc663bd80f4518ffc301dd10694f5a79f86 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h @@ -0,0 +1,9 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h new file mode 100644 index 0000000000000000000000000000000000000000..b81c80ac1efbf8ea2d24e9c0d524e12c75a3e061 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h @@ -0,0 +1,480 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at { +namespace detail { +// We prefer to convert through float for reduced-precision floating +// point types if we have a Vectorized specialization for float and we +// don't have one for the actual type in question. +template +struct should_prefer_converting_through_float + : std::bool_constant< + is_reduced_floating_point_v && + vec::is_vec_specialized_for_v && + !vec::is_vec_specialized_for_v> {}; + +template +constexpr auto should_prefer_converting_through_float_v = + should_prefer_converting_through_float::value; +} // namespace detail + +namespace vec { +// slow path +template +inline scalar_t vec_reduce_all( + const Op& vec_fun, + vec::Vectorized acc_vec, + int64_t size) { + using Vec = vec::Vectorized; + scalar_t acc_arr[Vec::size()]; + acc_vec.store(acc_arr); + for (const auto i : c10::irange(1, size)) { + std::array acc_arr_next = {0}; + acc_arr_next[0] = acc_arr[i]; + Vec acc_vec_next = Vec::loadu(acc_arr_next.data()); + acc_vec = vec_fun(acc_vec, acc_vec_next); + } + acc_vec.store(acc_arr); + return acc_arr[0]; +} + +template +struct VecReduceAllSIMD { + static inline scalar_t apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + return vec_reduce_all(vec_fun, acc_vec, Vectorized::size()); + } +}; + +#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \ + !defined(C10_MOBILE) +#if defined(CPU_CAPABILITY_AVX2) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + Vec v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = vec_fun(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = vec_fun(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = vec_fun(v, v1); + return _mm256_cvtss_f32(v); + } +}; +#endif // defined(CPU_CAPABILITY_AVX2) +#if defined(CPU_CAPABILITY_AVX512) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 256-bit shuffle + Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E); + v = vec_fun(v, v1); + // 128-bit shuffle + v1 = _mm512_shuffle_f32x4(v, v, 0xB1); + v = vec_fun(v, v1); + // 64-bit shuffle + v1 = _mm512_shuffle_ps(v, v, 0x4E); + v = vec_fun(v, v1); + // 32-bit shuffle + v1 = _mm512_shuffle_ps(v, v, 0xB1); + v = vec_fun(v, v1); + return _mm512_cvtss_f32(v); + } +}; +#endif // defined(CPU_CAPABILITY_AVX512) +#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && + // !defined(C10_MOBILE) + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + + // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, + // a4+a8, a1+a5, a2+a6, -, -, -, -] + float32x4_t v1_1 = vextq_f32(v, v, 2); + Vec v1 = v1_1; + // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] + v = vec_fun(v, v1); + + // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, + // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, + // -] + v1_1 = vrev64q_f32(v); + v1 = v1_1; + // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, + // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] + v = vec_fun(v, v1); + + return v[0]; + } +}; + +template <> +struct VecReduceAllSIMD>> { + static inline float apply( + const std::plus>& vec_fun, + const Vectorized& acc_vec) { + return vaddvq_f32(acc_vec); + } +}; +#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + // && !defined(CPU_CAPABILITY_SVE) + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + defined(CPU_CAPABILITY_SVE256) +template +struct VecReduceAllSIMD { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + svuint32_t ind = svdupq_n_u32(4, 5, 6, 7); + Vec v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 64-bit shuffle + ind = svdupq_n_u32(2, 3, 0, 1); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 32-bit shuffle + ind = svdupq_n_u32(1, 0, 2, 3); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + return svlasta(svpfalse(), v); + } +}; +#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + // && defined(CPU_CAPABILITY_SVE256) + +template +inline scalar_t vec_reduce_all( + const Op& vec_fun, + const Vectorized& acc_vec) { + return VecReduceAllSIMD::apply(vec_fun, acc_vec); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline scalar_t reduce_all( + const Op& vec_fun, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) + return vec_reduce_all(vec_fun, Vec::loadu(data, size), size); + int64_t d = Vec::size(); + Vec acc_vec = Vec::loadu(data); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + acc_vec = vec_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(vec_fun, acc_vec); +} + +// similar to reduce_all, but reduces into two outputs +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + auto loaded_data = Vec::loadu(data, size); + return std::pair( + vec_reduce_all(vec_fun1, loaded_data, size), + vec_reduce_all(vec_fun2, loaded_data, size)); + } + int64_t d = Vec::size(); + Vec acc_vec1 = Vec::loadu(data); + Vec acc_vec2 = Vec::loadu(data); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + acc_vec1 = vec_fun1(acc_vec1, data_vec); + acc_vec2 = vec_fun2(acc_vec2, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d); + acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d); + } + return std::pair( + vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2)); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) + return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size); + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + data_vec = map_fun(data_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + data_vec = map_fun(data_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map2_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + Vec data_vec = Vec::loadu(data, size); + Vec data2_vec = Vec::loadu(data2, size); + data_vec = map_fun(data_vec, data2_vec); + return vec_reduce_all(red_fun, data_vec, size); + } + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + Vec data2_vec = Vec::loadu(data2 + d); + data_vec = map_fun(data_vec, data2_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + Vec data2_vec = Vec::loadu(data2 + d, size - d); + data_vec = map_fun(data_vec, data2_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline scalar_t map3_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + const scalar_t* data3, + int64_t size) { + using Vec = vec::Vectorized; + if (size < Vec::size()) { + Vec data_vec = Vec::loadu(data, size); + Vec data2_vec = Vec::loadu(data2, size); + Vec data3_vec = Vec::loadu(data3, size); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + return vec_reduce_all(red_fun, data_vec, size); + } + + int64_t d = Vec::size(); + Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3)); + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(data + d); + Vec data2_vec = Vec::loadu(data2 + d); + Vec data3_vec = Vec::loadu(data3 + d); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + acc_vec = red_fun(acc_vec, data_vec); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(data + d, size - d); + Vec data2_vec = Vec::loadu(data2 + d, size - d); + Vec data3_vec = Vec::loadu(data3 + d, size - d); + data_vec = map_fun(data_vec, data2_vec, data3_vec); + acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d); + } + return vec_reduce_all(red_fun, acc_vec); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v>, + int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec output_vec = vec_fun(Vec::loadu(input_data + d)); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d)); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map2( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + const scalar_t* input_data2, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec = Vec::loadu(input_data + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec = Vec::loadu(input_data + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map3( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec1 = Vec::loadu(input_data1 + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec data_vec3 = Vec::loadu(input_data3 + d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec1 = Vec::loadu(input_data1 + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec data_vec3 = Vec::loadu(input_data3 + d, size - d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3); + output_vec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>, + int> = 0> +inline void map4( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + const scalar_t* input_data4, + int64_t size) { + using Vec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + Vec data_vec1 = Vec::loadu(input_data1 + d); + Vec data_vec2 = Vec::loadu(input_data2 + d); + Vec data_vec3 = Vec::loadu(input_data3 + d); + Vec data_vec4 = Vec::loadu(input_data4 + d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4); + output_vec.store(output_data + d); + } + if (size - d > 0) { + Vec data_vec1 = Vec::loadu(input_data1 + d, size - d); + Vec data_vec2 = Vec::loadu(input_data2 + d, size - d); + Vec data_vec3 = Vec::loadu(input_data3 + d, size - d); + Vec data_vec4 = Vec::loadu(input_data4 + d, size - d); + Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4); + output_vec.store(output_data + d, size - d); + } +} + +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..ad7daa651fd0c2a685cd52c5ef03b3994ffe1554 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h @@ -0,0 +1,652 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +namespace at::vec { +// BFloat16 specification +template +struct VecScalarType { + using type = scalar_t; +}; +template <> +struct VecScalarType { + using type = float; +}; +template <> +struct VecScalarType { + using type = float; +}; + +// This is different from at::acc_type since we only need to specialize BFloat16 +template +using vec_scalar_t = typename VecScalarType::type; + +// Vector conversion between float and bfloat16/half +template <> +inline std::tuple, Vectorized> convert_to_float< + BFloat16>(const Vectorized& a) { + return convert_bfloat16_float(a); +} + +template <> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized& a) { + return convert_half_float(a); +} + +template <> +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { + return convert_float_bfloat16(a, b); +} + +template <> +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { + return convert_float_half(a, b); +} + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float( + const scalar_t* data, + Vectorized& out1, + Vectorized& out2); + +template <> +inline void load_to_float( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_bf16(data, out1, out2); +} + +template <> +inline void load_to_float( + const Half* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_fp16(data, out1, out2); +} + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float(const scalar_t* data, Vectorized& out); + +template <> +inline void load_to_float( + const BFloat16* data, + Vectorized& out) { + load_fp32_from_bf16(data, out); +} + +template <> +inline void load_to_float(const Half* data, Vectorized& out) { + load_fp32_from_fp16(data, out); +} + +// Note that we already have specialized member of Vectorized for +// BFloat16 so the following functions would run smoothly: +// using Vec = Vectorized; +// Vec one = Vec(BFloat16(1)); +// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); +// +// Then why we still need to specialize "functional"? +// If we do specialization at Vectorized<> level, the above example would need +// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and +// "/". If we do specialization at vec::map<>() level, we have only 1 pair of +// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16 +// vector only. +// +// The following BFloat16 functionality will only do data type conversion for +// input and output vector (reduce functionality will only convert the final +// scalar back to bf16). Compared to Vectorized<> specialization, +// 1. better performance since we have less data type conversion; +// 2. less rounding error since immediate results are kept in fp32; +// 3. accumulation done on data type of fp32. +// +// If you plan to extend this file, please ensure adding unit tests at +// aten/src/ATen/test/vec_test_all_types.cpp +// +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + data_fvec0 = fVec::set( + data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(vec_fun, data_fvec0, fVec::size()); + } else { + return vec_reduce_all(vec_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); + acc_fvec1 = vec_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + acc_fvec0 = + fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(vec_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + fVec acc1_fvec = fVec::set( + data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc2_fvec = fVec::set( + data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); + return std::pair( + vec_reduce_all(vec_fun1, acc1_fvec, fVec::size()), + vec_reduce_all(vec_fun2, acc2_fvec, fVec::size())); + } else { + return std::pair( + vec_reduce_all(vec_fun1, data_fvec0, size), + vec_reduce_all(vec_fun2, data_fvec0, size)); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc1_fvec0, acc1_fvec1] = convert_to_float(acc_bvec); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc_bvec); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); + acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1); + acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); + acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); + acc1_fvec1 = fVec::set( + acc1_fvec1, + vec_fun1(acc1_fvec1, data_fvec1), + size - d - fVec::size()); + acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); + acc2_fvec1 = fVec::set( + acc2_fvec1, + vec_fun2(acc2_fvec1, data_fvec1), + size - d - fVec::size()); + } else { + acc1_fvec0 = + fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); + acc2_fvec0 = + fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); + } + } + acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1); + acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1); + return std::pair( + vec_reduce_all(vec_fun1, acc1_fvec0), + vec_reduce_all(vec_fun2, acc2_fvec0)); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + acc_fvec0 = map_fun(acc_fvec0); + acc_fvec1 = map_fun(acc_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0); + data_fvec1 = map_fun(data_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map2_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2, size); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + bVec acc2_bvec = bVec::loadu(data2); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc2_bvec); + acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0); + acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> +inline float map3_reduce_all( + const MapOp& map_fun, + const ReduceOp& red_fun, + const scalar_t* data, + const scalar_t* data2, + const scalar_t* data3, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + if (size < bVec::size()) { + bVec data_bvec = bVec::loadu(data, size); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2, size); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3, size); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + if (size > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + return vec_reduce_all(red_fun, data_fvec0, fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + return vec_reduce_all(red_fun, data_fvec0, size); + } + } + int64_t d = bVec::size(); + bVec acc_bvec = bVec::loadu(data); + auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec); + bVec acc2_bvec = bVec::loadu(data2); + auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc2_bvec); + bVec acc3_bvec = bVec::loadu(data3); + auto [acc3_fvec0, acc3_fvec1] = convert_to_float(acc3_bvec); + acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0); + acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1); + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = red_fun(acc_fvec1, data_fvec1); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + if (size - d > fVec::size()) { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); + acc_fvec0 = red_fun(acc_fvec0, data_fvec0); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + } else { + data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + } + } + acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); + return vec_reduce_all(red_fun, acc_fvec0); +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v>), + int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(input_data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(input_data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline void map( + const Op& vec_fun, + scalar_t* output_data, + const float* input_data, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + fVec data_fvec0 = fVec::loadu(input_data + d); + fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size()); + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + fVec data_fvec0, data_fvec1; + if (size - d > fVec::size()) { + data_fvec0 = fVec::loadu(input_data + d); + data_fvec1 = + fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); + } else { + // choose to align with behaviour of bVec::loadu(ptr, size), + // which leaves data_fvec1 uninitialized + data_fvec0 = fVec::loadu(input_data + d, size - d); + } + fVec output_fvec0 = vec_fun(data_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map2( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data, + const scalar_t* input_data2, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data_bvec = bVec::loadu(input_data + d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data_bvec = bVec::loadu(input_data + d, size - d); + auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0); + fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map3( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data1_bvec = bVec::loadu(input_data1 + d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); + fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0); + fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +template < + typename scalar_t, + typename Op, + typename std::enable_if_t< + !(!detail::should_prefer_converting_through_float_v && + std::is_invocable_v< + Op, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized, + vec::Vectorized>), + int> = 0> +inline void map4( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* input_data1, + const scalar_t* input_data2, + const scalar_t* input_data3, + const scalar_t* input_data4, + int64_t size) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec data1_bvec = bVec::loadu(input_data1 + d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + bVec data4_bvec = bVec::loadu(input_data4 + d); + auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d); + } + if (size - d > 0) { + bVec data1_bvec = bVec::loadu(input_data1 + d, size - d); + auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec); + bVec data2_bvec = bVec::loadu(input_data2 + d, size - d); + auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec); + bVec data3_bvec = bVec::loadu(input_data3 + d, size - d); + auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); + bVec data4_bvec = bVec::loadu(input_data4 + d, size - d); + auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); + output_bvec.store(output_data + d, size - d); + } +} + +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..fd3d3a65215450308a807f98d28b701f28e2ff22 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..60e0025a2d63d264c9baef2fce846ae400b73cc5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/sve_helper.h @@ -0,0 +1,85 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include + +#if defined(CPU_CAPABILITY_SVE) + +// Define the data type of VLS(vector-length specific). +typedef svbool_t vls_pred_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint8_t vls_int8_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint16_t vls_int16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint32_t vls_int32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint64_t vls_int64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint8_t vls_uint8_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint16_t vls_uint16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint32_t vls_uint32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint64_t vls_uint64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat16_t vls_float16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svbfloat16_t vls_bfloat16_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat32_t vls_float32_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat64_t vls_float64_t + __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); + +#define ptrue svptrue_b8() +#define ZERO_S8 svdup_n_s8(0) +#define ZERO_S16 svdup_n_s16(0) +#define ZERO_S32 svdup_n_s32(0) +#define ZERO_S64 svdup_n_s64(0) +#define ZERO_U8 svdup_n_u8(0) +#define ZERO_U16 svdup_n_u16(0) +#define ZERO_U32 svdup_n_u32(0) +#define ZERO_U64 svdup_n_u64(0) +#define ZERO_F16 svdup_n_f16(0.f) +#define ZERO_F32 svdup_n_f32(0.f) +#define ZERO_F64 svdup_n_f64(0.0) +#define ONE_S8 svdup_n_s8(1) +#define ONE_S16 svdup_n_s16(1) +#define ONE_S32 svdup_n_s32(1) +#define ONE_S64 svdup_n_s64(1) +#define ONE_U8 svdup_n_u8(1) +#define ONE_U16 svdup_n_u16(1) +#define ONE_U32 svdup_n_u32(1) +#define ONE_U64 svdup_n_u64(1) +#define ONE_F16 svdup_n_f16(1.f) +#define ONE_BF16 svdup_n_bf16(1.f) +#define ONE_F32 svdup_n_f32(1.f) +#define ONE_F64 svdup_n_f64(1.0) +#define ALL_S8_TRUE_MASK svdup_n_s8(0xff) +#define ALL_S8_FALSE_MASK svdup_n_s8(0x0) +#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff) +#define ALL_S16_FALSE_MASK svdup_n_s16(0x0) +#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff) +#define ALL_S32_FALSE_MASK svdup_n_s32(0x0) +#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff) +#define ALL_S64_FALSE_MASK svdup_n_s64(0x0) +#define ALL_U8_TRUE_MASK svdup_n_u8(0x01) +#define ALL_U8_FALSE_MASK svdup_n_u8(0x00) +#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK) +#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK) +#define ALL_BF16_TRUE_MASK svreinterpret_bf16_s16(ALL_S16_TRUE_MASK) +#define ALL_BF16_FALSE_MASK svreinterpret_bf16_s16(ALL_S16_FALSE_MASK) +#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK) +#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK) +#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK) +#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK) + +#endif // defined(CPU_CAPABILITY_SVE) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..bb712e8d7ee510503f0a812fdcd4617b7678922a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_bfloat16.h @@ -0,0 +1,598 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_bfloat16_t values; + + public: + using value_type = BFloat16; + using size_type = int; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(BFloat16); + } + + Vectorized(); + Vectorized(svbfloat16_t v) : values(v) {} + Vectorized(int val); + Vectorized(BFloat16 val); + + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ BFloat16 buffer[size()] = {vals...}; + values = svld1_bf16(ptrue, reinterpret_cast(buffer)); + } + + operator svbfloat16_t() const { + return values; + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s16(ptrue, svreinterpret_s16_bf16(mask_), ALL_S16_TRUE_MASK); + return svsel_bf16(mask, b, a); + } + template + static Vectorized arange( + BFloat16 base = 0.f, + step_t step = static_cast(1)) { + __at_align__ BFloat16 buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_bf16(ptrue, reinterpret_cast(buffer)); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_bf16(svwhilelt_b16(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_bf16(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b16(0ull, count); + return svld1_bf16(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + __at_align__ bfloat16_t tmp[size()]; + std::memset(tmp, 0, sizeof(tmp)); + if (count == size()) { + svst1_bf16(ptrue, reinterpret_cast(tmp), values); + } else { + svbool_t pg = svwhilelt_b16(0ull, count); + svst1_bf16(pg, reinterpret_cast(tmp), values); + } + std::memcpy( + reinterpret_cast(ptr), + reinterpret_cast(tmp), + count * sizeof(bfloat16_t)); + } + const BFloat16& operator[](int idx) const = delete; + BFloat16& operator[](int idx) = delete; + int64_t zero_mask() const { + int64_t mask = 0; + // returns an integer mask where all zero elements are translated to + // 1-bit and others are translated to 0-bit int64_t mask = 0; + __at_align__ int16_t mask_array[size()]; + + svbool_t svbool_mask = + svcmpeq_f16(ptrue, svreinterpret_f16_bf16(values), ZERO_F16); + svst1_s16( + ptrue, + mask_array, + svsel_s16(svbool_mask, ALL_S16_TRUE_MASK, ALL_S16_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const; + bool has_inf_nan() const; + Vectorized map(BFloat16 (*f)(BFloat16)) const { + __at_align__ BFloat16 tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = svdup_n_u16(0x7FFF); + auto vals = svreinterpret_u16_bf16(values); + vals = svand_u16_x(ptrue, vals, mask); + return svreinterpret_bf16_u16(vals); + } + Vectorized angle() const; + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const; + Vectorized acosh() const; + Vectorized asin() const; + Vectorized atan() const; + Vectorized atanh() const; + Vectorized atan2(const Vectorized& b) const; + Vectorized copysign(const Vectorized& sign) const; + Vectorized erf() const; + Vectorized erfc() const; + Vectorized erfinv() const; + Vectorized exp() const; + Vectorized exp2() const; + Vectorized expm1() const; + Vectorized exp_u20() const { + return exp(); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const; + Vectorized hypot(const Vectorized& b) const; + Vectorized i0() const; + Vectorized i0e() const; + Vectorized digamma() const; + Vectorized igamma(const Vectorized& x) const; + Vectorized igammac(const Vectorized& x) const; + Vectorized nextafter(const Vectorized& b) const; + Vectorized log() const; + Vectorized log2() const; + Vectorized log10() const; + Vectorized log1p() const; + Vectorized frac() const; + Vectorized sin() const; + Vectorized sinh() const; + Vectorized cos() const; + Vectorized cosh() const; + Vectorized ceil() const; + Vectorized floor() const; + Vectorized neg() const { + auto mask = svdup_n_u16(0x8000); + auto vals = svreinterpret_u16_bf16(values); + vals = sveor_u16_x(ptrue, vals, mask); + return svreinterpret_bf16_u16(vals); + } + Vectorized round() const; + Vectorized tan() const; + Vectorized tanh() const; + Vectorized trunc() const; + Vectorized lgamma() const; + Vectorized sqrt() const; + Vectorized reciprocal() const; + Vectorized rsqrt() const; + Vectorized pow(const Vectorized& b) const; + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const; + + Vectorized operator!=(const Vectorized& other) const; + + Vectorized operator<(const Vectorized& other) const; + + Vectorized operator<=(const Vectorized& other) const; + + Vectorized operator>(const Vectorized& other) const; + + Vectorized operator>=(const Vectorized& other) const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +#if defined(__GNUC__) && __GNUC__ == 14 +// Workaround for gcc-14.2.0 ICE during RTL pass: vregs when compiling for SVE +__attribute__((optimize("no-tree-vectorize"))) +#endif +inline std::tuple, Vectorized> +convert_bfloat16_float(const Vectorized& a) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f)); + auto bf16_vec1 = svzip1_bf16(zero, a); + auto bf16_vec2 = svzip2_bf16(zero, a); + auto x1 = svreinterpret_f32_bf16(bf16_vec1); + auto x2 = svreinterpret_f32_bf16(bf16_vec2); + return {Vectorized(x1), Vectorized(x2)}; +} + +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + svbfloat16_t x1 = svcvt_bf16_f32_z(ptrue, a); + svbfloat16_t x2 = svcvt_bf16_f32_z(ptrue, b); + return Vectorized(svuzp1_bf16(x1, x2)); +} + +inline void load_fp32_from_bf16(const BFloat16* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_bf16( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + Vectorized bf16_vec = Vectorized::loadu(data); + auto floats = convert_bfloat16_float(bf16_vec); + out1 = std::get<0>(floats); + out2 = std::get<1>(floats); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return convert_float_bfloat16( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::plus>(), a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::minus>(), a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::multiplies>(), a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float(std::divides>(), a, b); +} + +inline Vectorized::Vectorized() { + auto vals_f = svdup_n_f32(0); + values = convert_float_bfloat16(vals_f, vals_f); +} + +inline Vectorized::Vectorized(int val) { + auto vals_f = svdup_n_f32(val); + values = convert_float_bfloat16(vals_f, vals_f); +} + +inline Vectorized::Vectorized(BFloat16 val) { + auto vals_f = svdup_n_f32((float)val); + values = convert_float_bfloat16(vals_f, vals_f); +} + +bool inline Vectorized::has_inf_nan() const { + auto [v1, v2] = convert_bfloat16_float(values); + return v1.has_inf_nan() || v2.has_inf_nan(); +} +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +#define DEFINE_BF16_FUNC_VIA_FLOAT(func_name) \ + Vectorized inline Vectorized::func_name() const { \ + auto [v1, v2] = convert_bfloat16_float(*this); \ + v1 = v1.func_name(); \ + v2 = v2.func_name(); \ + return convert_float_bfloat16(v1, v2); \ + } + +#define DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(func_name) \ + Vectorized inline Vectorized::func_name( \ + const Vectorized& a) const { \ + auto [v1, v2] = convert_bfloat16_float(*this); \ + auto [v3, v4] = convert_bfloat16_float(a); \ + v1 = v1.func_name(v3); \ + v2 = v2.func_name(v4); \ + return convert_float_bfloat16(v1, v2); \ + } + +DEFINE_BF16_FUNC_VIA_FLOAT(isnan) +DEFINE_BF16_FUNC_VIA_FLOAT(angle) +DEFINE_BF16_FUNC_VIA_FLOAT(acos) +DEFINE_BF16_FUNC_VIA_FLOAT(acosh) +DEFINE_BF16_FUNC_VIA_FLOAT(asin) +DEFINE_BF16_FUNC_VIA_FLOAT(atan) +DEFINE_BF16_FUNC_VIA_FLOAT(atanh) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign) +DEFINE_BF16_FUNC_VIA_FLOAT(erf) +DEFINE_BF16_FUNC_VIA_FLOAT(erfc) +DEFINE_BF16_FUNC_VIA_FLOAT(exp) +DEFINE_BF16_FUNC_VIA_FLOAT(exp2) +DEFINE_BF16_FUNC_VIA_FLOAT(expm1) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot) +DEFINE_BF16_FUNC_VIA_FLOAT(i0) +DEFINE_BF16_FUNC_VIA_FLOAT(i0e) +DEFINE_BF16_FUNC_VIA_FLOAT(digamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter) +DEFINE_BF16_FUNC_VIA_FLOAT(log) +DEFINE_BF16_FUNC_VIA_FLOAT(log2) +DEFINE_BF16_FUNC_VIA_FLOAT(log10) +DEFINE_BF16_FUNC_VIA_FLOAT(log1p) +DEFINE_BF16_FUNC_VIA_FLOAT(sin) +DEFINE_BF16_FUNC_VIA_FLOAT(sinh) +DEFINE_BF16_FUNC_VIA_FLOAT(cos) +DEFINE_BF16_FUNC_VIA_FLOAT(cosh) +DEFINE_BF16_FUNC_VIA_FLOAT(ceil) +DEFINE_BF16_FUNC_VIA_FLOAT(floor) +DEFINE_BF16_FUNC_VIA_FLOAT(round) +DEFINE_BF16_FUNC_VIA_FLOAT(tan) +DEFINE_BF16_FUNC_VIA_FLOAT(tanh) +DEFINE_BF16_FUNC_VIA_FLOAT(trunc) +DEFINE_BF16_FUNC_VIA_FLOAT(lgamma) +DEFINE_BF16_FUNC_VIA_FLOAT(sqrt) +DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal) +DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow) + +Vectorized inline Vectorized::operator==( + const Vectorized& other) const { + auto [f1, f2] = convert_bfloat16_float(values); + auto [f3, f4] = convert_bfloat16_float(other); + svbool_t mask1 = svcmpeq_f32(ptrue, f1, f3); + svbool_t mask2 = svcmpeq_f32(ptrue, f2, f4); + auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + + auto bf16_1 = svreinterpret_bf16_f32(res1); + auto bf16_2 = svreinterpret_bf16_f32(res2); + return svuzp1_bf16(bf16_1, bf16_2); +} +Vectorized inline Vectorized::operator!=( + const Vectorized& other) const { + auto [f1, f2] = convert_bfloat16_float(values); + auto [f3, f4] = convert_bfloat16_float(other); + svbool_t mask1 = svcmpne_f32(ptrue, f1, f3); + svbool_t mask2 = svcmpne_f32(ptrue, f2, f4); + auto res1 = svsel_f32(mask1, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + auto res2 = svsel_f32(mask2, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + + auto bf16_1 = svreinterpret_bf16_f32(res1); + auto bf16_2 = svreinterpret_bf16_f32(res2); + return svuzp1_bf16(bf16_1, bf16_2); +} +Vectorized inline Vectorized::operator>( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 > v3, v2 > v4); +} +Vectorized inline Vectorized::operator>=( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 >= v3, v2 >= v4); +} +Vectorized inline Vectorized::operator<( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 < v3, v2 < v4); +} +Vectorized inline Vectorized::operator<=( + const Vectorized& other) const { + auto [v1, v2] = convert_bfloat16_float(*this); + auto [v3, v4] = convert_bfloat16_float(other); + return convert_float_bfloat16(v1 <= v3, v2 <= v4); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&clamp_max), + a, + max); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&clamp_min), + a, + min); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return clamp_min(clamp_max(a, max), min); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + svand_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + svorr_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_bf16_u16( + sveor_u16_x(ptrue, svreinterpret_u16_bf16(a), svreinterpret_u16_bf16(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_bf16( + ptrue, + const_cast(reinterpret_cast(dst)) + i, + svldnt1_bf16( + ptrue, + const_cast(reinterpret_cast(src)) + + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b16(i, n); + svst1_bf16( + pg, + const_cast(reinterpret_cast(dst)) + i, + svldnt1_bf16( + pg, + const_cast(reinterpret_cast(src)) + + i)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b + c; +} + +#endif // defined(CPU_CAPABILITY_SVE) && defined(__ARM_FEATURE_BF16) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h new file mode 100644 index 0000000000000000000000000000000000000000..d11be323e05416cb0d7ef821e8bd0dde7ad1d0c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_common_sve.h @@ -0,0 +1,241 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include + +#include +#include + +#if defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#include +#include +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#define DEFINE_SVE_CAST(t1_t, t1_prefix, t2_t, t2_prefix) \ + template <> \ + inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_##t1_prefix##_##t2_prefix(src); \ + } \ + template <> \ + inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_##t2_prefix##_##t1_prefix(src); \ + } + +DEFINE_SVE_CAST(int64_t, s64, double, f64) +DEFINE_SVE_CAST(int32_t, s32, double, f64) +DEFINE_SVE_CAST(int16_t, s16, double, f64) +DEFINE_SVE_CAST(int64_t, s64, float, f32) +DEFINE_SVE_CAST(int32_t, s32, float, f32) +DEFINE_SVE_CAST(int16_t, s16, float, f32) +DEFINE_SVE_CAST(float, f32, double, f64) + +#ifdef __ARM_FEATURE_BF16 +DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16) +DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16) +DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16) +#endif // __ARM_FEATURE_BF16 + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex_) { + svint64_t vindex = + svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svld1_gather_s64index_f64(ptrue, base_addr, vindex); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex_) { + svint32_t vindex = + svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svld1_gather_s32index_f32(ptrue, base_addr, vindex); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex_, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); + svint64_t vindex = + svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svsel_f64( + mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex_, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); + svint32_t vindex = + svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svsel_f32( + mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); + return svsub_s64_x( + ptrue, + svreinterpret_s64_f64(x), + svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return svcvt_s32_f32_x(ptrue, src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3} + // b = {b0, b1, b2, b3} + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized(svzip1_f64(a, b)), + Vectorized(svzip2_f64(a, b))); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair( + Vectorized(svzip1_f32(a, b)), Vectorized(svzip2_f32(a, b))); +} + +#ifdef __ARM_FEATURE_BF16 +template <> +std::pair< + Vectorized, + Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair( + Vectorized(svzip1_bf16(a, b)), + Vectorized(svzip2_bf16(a, b))); +} +#endif // __ARM_FEATURE_BF16 + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vectorized(svuzp1_f64(a, b)), + Vectorized(svuzp2_f64(a, b))); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + Vectorized(svuzp1_f32(a, b)), Vectorized(svuzp2_f32(a, b))); +} + +#ifdef __ARM_FEATURE_BF16 +template <> +std::pair< + Vectorized, + Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + Vectorized(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)), + Vectorized(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b))); +} +#endif // __ARM_FEATURE_BF16 + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h new file mode 100644 index 0000000000000000000000000000000000000000..8abd6d275e80db7658c8c187ccc78031b6c600b5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_double.h @@ -0,0 +1,622 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_float64_t values; + + public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(double); + } + Vectorized() { + values = svdup_n_f64(0); + } + Vectorized(svfloat64_t v) : values(v) {} + Vectorized(double val) { + values = svdup_n_f64(val); + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ double buffer[size()] = {vals...}; + values = svld1_f64(ptrue, buffer); + } + operator svfloat64_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each element is 1 if the corresponding bit in + // 'mask' is set, 0 otherwise. + __at_align__ int64_t flag_arr[size()]; + for (int i = 0; i < size(); i++) { + flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0; + } + // Load the flag array into an SVE int64 vector. + svint64_t int_mask = svld1_s64(svptrue_b64(), flag_arr); + // Compare each lane of int_mask to 0; returns an svbool_t predicate where + // true indicates a nonzero flag. + svbool_t blend_mask = svcmpne_n_s64(svptrue_b64(), int_mask, 0); + + // Use svsel to select elements from b where the predicate is true, else + // from a. + svfloat64_t result = svsel(blend_mask, b.values, a.values); + return Vectorized(result); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), ALL_S64_TRUE_MASK); + return svsel_f64(mask, b, a); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + __at_align__ double buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f64(ptrue, buffer); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f64(svwhilelt_b64(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f64(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b64(0ull, count); + return svld1_f64(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f64(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b64(0ull, count); + svst1_f64(pg, reinterpret_cast(ptr), values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int64_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f64(ptrue, values, ZERO_F64); + svst1_s64( + ptrue, + mask_array, + svsel_s64(svbool_mask, ALL_S64_TRUE_MASK, ALL_S64_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f64(ptrue, values, ZERO_F64); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any( + ptrue, + svcmpuo_f64(ptrue, svsub_f64_x(ptrue, values, values), ZERO_F64)); + } + Vectorized map(double (*f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f64_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f64(NAN); + const auto nan_mask = svcmpuo_f64(ptrue, values, ZERO_F64); + const auto pi = svdup_n_f64(c10::pi); + + const auto neg_mask = svcmplt_f64(ptrue, values, ZERO_F64); + auto angle = svsel_f64(neg_mask, pi, ZERO_F64); + angle = svsel_f64(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosdx_u10sve(values)), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshdx_u10sve(values)), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asindx_u10sve(values)), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhdx_u10sve(values)), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atandx_u10sve(values)), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + Vectorized(Sleef_atanhdx_u10sve(values)), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2dx_u10sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized copysign(const Vectorized& sign) const { + USE_SLEEF( + { return Vectorized(Sleef_copysigndx_sve(values, sign)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erfdx_u10sve(values)), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcdx_u15sve(values)), map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF( + Vectorized(Sleef_expdx_u10sve(values)), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2dx_u10sve(values)), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1dx_u10sve(values)), map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmoddx_sve(values, q)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotdx_u05sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterdx_sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logdx_u10sve(values)), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2dx_u10sve(values)), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10dx_u10sve(values)), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + Vectorized(Sleef_log1pdx_u10sve(values)), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sindx_u10sve(values)), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhdx_u10sve(values)), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosdx_u10sve(values)), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshdx_u10sve(values)), map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f64_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f64_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f64_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f64_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tandx_u10sve(values)), map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( + Vectorized(Sleef_tanhdx_u10sve(values)), map(std::tanh)); + } + Vectorized trunc() const { + return svrintz_f64_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammadx_u10sve(values)), map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f64_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f64_x(ptrue, values, ONE_F64); + } + Vectorized rsqrt() const { + return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powdx_u10sve(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return svadd_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return svsub_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return svmul_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_f64_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return svmax_f64_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return svmin_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return svmin_f64_x(ptrue, max, svmax_f64_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return svmin_f64_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return svmax_f64_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + svand_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + svorr_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f64_s64( + sveor_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f64(ptrue, dst + i, svldnt1_f64(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b64(i, n); + svst1_f64(pg, dst + i, svldnt1_f64(pg, src + i)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmad_f64_x(ptrue, a, b, c); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmsb_f64_x(ptrue, a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svnmsb_f64_x(ptrue, a, b, c); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svnmad_f64_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h new file mode 100644 index 0000000000000000000000000000000000000000..008b7bb711ad0888d8ba8fac509c6e8f31599c28 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_float.h @@ -0,0 +1,760 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + vls_float32_t values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(float); + } + Vectorized() { + values = svdup_n_f32(0); + } + Vectorized(svfloat32_t v) : values(v) {} + Vectorized(float val) { + values = svdup_n_f32(val); + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ float buffer[size()] = {vals...}; + values = svld1_f32(ptrue, buffer); + } + operator svfloat32_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each element is 1 if the corresponding bit in + // 'mask' is set, 0 otherwise. + __at_align__ int32_t flag_arr[size()]; + for (int i = 0; i < size(); i++) { + flag_arr[i] = (mask & (1ULL << i)) ? 1 : 0; + } + // Load the flag array into an SVE int32 vector. + svint32_t int_mask = svld1_s32(svptrue_b32(), flag_arr); + // Compare each lane of int_mask to 0; returns an svbool_t predicate where + // true indicates a nonzero flag. + svbool_t blend_mask = svcmpne_n_s32(svptrue_b32(), int_mask, 0); + // Use svsel to select elements from b where the predicate is true, else + // from a. + svfloat32_t result = svsel_f32(blend_mask, b.values, a.values); + return Vectorized(result); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = + svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), ALL_S32_TRUE_MASK); + return svsel_f32(mask, b, a); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + __at_align__ float buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f32(ptrue, buffer); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f32(svwhilelt_b32(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_f32(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f32(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b32(0ull, count); + svst1_f32(pg, reinterpret_cast(ptr), values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int32_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32); + svst1_s32( + ptrue, + mask_array, + svsel_s32(svbool_mask, ALL_S32_TRUE_MASK, ALL_S32_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) + mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any( + ptrue, + svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32)); + } + Vectorized map(float (*f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f32_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f32(NAN); + const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32); + const auto pi = svdup_n_f32(c10::pi); + + const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32); + auto angle = svsel_f32(neg_mask, pi, ZERO_F32); + angle = svsel_f32(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosfx_u10sve(values)), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshfx_u10sve(values)), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asinfx_u10sve(values)), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhfx_u10sve(values)), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atanfx_u10sve(values)), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + Vectorized(Sleef_atanhfx_u10sve(values)), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2fx_u10sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized copysign(const Vectorized& sign) const { + + USE_SLEEF( + { return Vectorized(Sleef_copysignfx_sve(values, sign)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erffx_u10sve(values)), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcfx_u15sve(values)), map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF( + Vectorized(Sleef_expfx_u10sve(values)), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2fx_u10sve(values)), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1fx_u10sve(values)), map(std::expm1)); + } + // Implementation copied from Arm Optimized Routines: + // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c + Vectorized exp_u20() const { + // special case to handle special inputs that are too large or too small + // i.e. where there's at least one element x, s.t. |x| >= 87.3... + svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f); + if (svptest_any(svptrue_b32(), is_special_case)) { + return exp(); + } + const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f); + const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f); + const svfloat32_t c1 = svdup_n_f32(0.5f); + const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f); + + const float shift = 0x1.803f8p17f; + + /* n = round(x/(ln2/N)). */ + svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift); + svfloat32_t n = svsub_x(svptrue_b32(), z, shift); + + /* r = x - n*ln2/N. */ + svfloat32_t r = values; + r = svmls_x(svptrue_b32(), r, n, ln2_hi); + r = svmls_x(svptrue_b32(), r, n, ln2_lo); + + /* scale = 2^(n/N). */ + svfloat32_t scale = svexpa(svreinterpret_u32(z)); + + /* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */ + svfloat32_t r2 = svmul_x(svptrue_b32(), r, r); + svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1); + return svmla_x(svptrue_b32(), scale, scale, poly); + } + Vectorized fexp_u20() const { + return exp_u20(); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmodfx_sve(values, q)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotfx_u05sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterfx_sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logfx_u10sve(values)), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2fx_u10sve(values)), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10fx_u10sve(values)), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + Vectorized(Sleef_log1pfx_u10sve(values)), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sinfx_u10sve(values)), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhfx_u10sve(values)), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosfx_u10sve(values)), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshfx_u10sve(values)), map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f32_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f32_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f32_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f32_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tanfx_u10sve(values)), map(std::tan)); + } + // Implementation is picked from + // https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179 + Vectorized tanh() const { + // Constants used for the tanh calculation. + const svfloat32_t CONST_1 = + svdup_n_f32(1.f); // Constant 1.0f for the tanh formula. + const svfloat32_t CONST_2 = svdup_n_f32( + 2.f); // Constant 2.0f for the tanh formula (used in exp(2x)). + const svfloat32_t CONST_MIN_TANH = svdup_n_f32( + -10.f); // Minimum threshold for input values to prevent overflow. + const svfloat32_t CONST_MAX_TANH = svdup_n_f32( + 10.f); // Maximum threshold for input values to prevent overflow. + + // Step 1: Clamp the values within the range [-10, 10] to prevent overflow + // during exponentiation. The tanh function approaches ±1 rapidly as the + // input grows large, so we limit the input range to avoid numerical + // instability. svmax_f32_z ensures values are greater than -10, and + // svmin_f32_z ensures they are less than 10. + svfloat32_t x = svmin_f32_z( + ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH); + + // Step 2: Calculate exp(2 * x), where x is the clamped value. + // svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of + // the result (via Vectorized, then auto-converts back to + // svfloat32_t). + svfloat32_t exp2x = + Vectorized(svmul_f32_z(ptrue, CONST_2, x)).exp_u20(); + + // Step 3: Calculate the numerator of the tanh function, which is exp(2x) + // - 1. + svfloat32_t num = svsub_f32_z(ptrue, exp2x, CONST_1); + + // Step 4: Calculate the denominator of the tanh function, which is exp(2x) + // + 1. + svfloat32_t den = svadd_f32_z(ptrue, exp2x, CONST_1); + + // Step 5: Calculate the tanh function as the ratio of the numerator and + // denominator: num / den. + svfloat32_t tanh = svdiv_f32_z(ptrue, num, den); + + // Return the calculated tanh values. + return tanh; + } + Vectorized trunc() const { + return svrintz_f32_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammafx_u10sve(values)), map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f32_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f32_x(ptrue, values, ONE_F32); + } + Vectorized rsqrt() const { + return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powfx_u10sve(values, b)); }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return svadd_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return svsub_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return svmul_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_f32_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return svmax_f32_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return svmin_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return svmin_f32_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return svmax_f32_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return svreinterpret_f32_s32( + sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +Vectorized inline Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b32(i, n); + svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i)); + } +} + +template <> +inline void convert(const float* src, at::Half* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svuzp1_f16( + svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svuzp1_f16( + svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +} + +template <> +inline void convert(const at::Half* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svzip1_f16( + svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svzip1_f16( + svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +} + +template <> +inline void convert(const bool* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmad_f32_x(ptrue, a, b, c); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svmsb_f32_x(ptrue, a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svnmsb_f32_x(ptrue, a, b, c); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return svnmad_f32_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h new file mode 100644 index 0000000000000000000000000000000000000000..3dee484491f505993e1c523591b88747e782ede0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_int.h @@ -0,0 +1,504 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +#define VEC_INT_SVE_TEMPLATE(vl, bit) \ + template <> \ + struct is_vec_specialized_for : std::bool_constant {}; \ + \ + template <> \ + class Vectorized { \ + private: \ + vls_int##bit##_t values; \ + \ + public: \ + using value_type = int##bit##_t; \ + using size_type = int; \ + static constexpr size_type size() { \ + return vl; \ + } \ + Vectorized() { \ + values = svdup_n_s##bit(0); \ + } \ + Vectorized(svint##bit##_t v) : values(v) {} \ + Vectorized(int##bit##_t val) { \ + values = svdup_n_s##bit(val); \ + } \ + template < \ + typename... Args, \ + typename = std::enable_if_t<(sizeof...(Args) == size())>> \ + Vectorized(Args... vals) { \ + __at_align__ int##bit##_t buffer[size()] = {vals...}; \ + values = svld1_s##bit(ptrue, buffer); \ + } \ + operator svint##bit##_t() const { \ + return values; \ + } \ + template \ + static Vectorized blend( \ + const Vectorized& a, \ + const Vectorized& b) { \ + __at_align__ int##bit##_t flag_arr[size()]; \ + for (int i = 0; i < size(); ++i) { \ + flag_arr[i] = (i < 64 && (mask & (1ULL << i))) ? 1 : 0; \ + } \ + svbool_t blend_mask = svcmpne_n_s##bit( \ + svptrue_b##bit(), svld1_s##bit(svptrue_b##bit(), flag_arr), 0); \ + return Vectorized( \ + svsel_s##bit(blend_mask, b.values, a.values)); \ + } \ + static Vectorized blendv( \ + const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& mask_) { \ + svbool_t mask = svcmpeq_s##bit(ptrue, mask_, ALL_S##bit##_TRUE_MASK); \ + return svsel_s##bit(mask, b, a); \ + } \ + /* step sometimes requires a higher precision type (e.g., T=int, \ + * step_t=double) */ \ + template \ + static Vectorized arange( \ + int##bit##_t base = 0, \ + step_t step = static_cast(1)) { \ + __at_align__ int##bit##_t buffer[size()]; \ + for (int64_t i = 0; i < size(); i++) { \ + buffer[i] = base + i * step; \ + } \ + return svld1_s##bit(ptrue, buffer); \ + } \ + static Vectorized set( \ + const Vectorized& a, \ + const Vectorized& b, \ + int##bit##_t count = size()) { \ + if (count == 0) { \ + return a; \ + } else if (count < size()) { \ + return svsel_s##bit(svwhilelt_b##bit(0ull, count), b, a); \ + } \ + return b; \ + } \ + static Vectorized loadu( \ + const void* ptr, \ + int64_t count = size()) { \ + if (count == size()) \ + return svld1_s##bit( \ + ptrue, reinterpret_cast(ptr)); \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + return svld1_s##bit(pg, reinterpret_cast(ptr)); \ + } \ + void store(void* ptr, int64_t count = size()) const { \ + if (count == size()) { \ + svst1_s##bit(ptrue, reinterpret_cast(ptr), values); \ + } else { \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + svst1_s##bit(pg, reinterpret_cast(ptr), values); \ + } \ + } \ + const int##bit##_t& operator[](int idx) const = delete; \ + int##bit##_t& operator[](int idx) = delete; \ + Vectorized abs() const { \ + return svabs_s##bit##_x(ptrue, values); \ + } \ + Vectorized real() const { \ + return values; \ + } \ + Vectorized imag() const { \ + return svdup_n_s##bit(0); \ + } \ + Vectorized conj() const { \ + return values; \ + } \ + Vectorized frac() const; \ + Vectorized neg() const { \ + return svneg_s##bit##_x(ptrue, values); \ + } \ + Vectorized operator==( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpeq_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator!=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpne_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<( \ + const Vectorized& other) const { \ + svbool_t mask = svcmplt_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmple_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpgt_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>=( \ + const Vectorized& other) const { \ + svbool_t mask = svcmpge_s##bit(ptrue, values, other); \ + return svsel_s##bit( \ + mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized eq(const Vectorized& other) const; \ + Vectorized ne(const Vectorized& other) const; \ + Vectorized gt(const Vectorized& other) const; \ + Vectorized ge(const Vectorized& other) const; \ + Vectorized lt(const Vectorized& other) const; \ + Vectorized le(const Vectorized& other) const; \ + }; \ + template <> \ + Vectorized inline operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return svadd_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return svsub_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return svmul_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline maximum( \ + const Vectorized& a, const Vectorized& b) { \ + return svmax_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline minimum( \ + const Vectorized& a, const Vectorized& b) { \ + return svmin_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, svmax_s##bit##_x(ptrue, min, a)); \ + } \ + template <> \ + Vectorized inline clamp_max( \ + const Vectorized& a, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, a); \ + } \ + template <> \ + Vectorized inline clamp_min( \ + const Vectorized& a, \ + const Vectorized& min) { \ + return svmax_s##bit##_x(ptrue, min, a); \ + } \ + template <> \ + Vectorized inline operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return svand_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return svorr_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + Vectorized inline operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return sveor_s##bit##_x(ptrue, a, b); \ + } \ + template <> \ + inline Vectorized operator~( \ + const Vectorized& a) { \ + return sveor_s##bit##_x(ptrue, a, svdup_n_s##bit(-1)); \ + } \ + Vectorized inline Vectorized::eq( \ + const Vectorized& other) const { \ + return (*this == other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ne( \ + const Vectorized& other) const { \ + return (*this != other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::gt( \ + const Vectorized& other) const { \ + return (*this > other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ge( \ + const Vectorized& other) const { \ + return (*this >= other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::lt( \ + const Vectorized& other) const { \ + return (*this < other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::le( \ + const Vectorized& other) const { \ + return (*this <= other) & Vectorized(1); \ + } + +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int64_t), 64) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int32_t), 32) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int16_t), 16) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int8_t), 8) + +template +Vectorized inline intdiv_nosve( + const Vectorized& a, + const Vectorized& b) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] /= values_b[i]; + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_s64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return svdiv_s32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +inline void convert(const int32_t* src, int64_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); + } +} + +template <> +inline void convert(const int64_t* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = + svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = + svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b32(i, n); + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +} + +template <> +inline void convert(const bool* src, int64_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = + svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_64 = svwhilelt_b64(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = + svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +} + +template <> +inline void convert(const bool* src, int32_t* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = + svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +} + +template <> +inline void convert(const uint8_t* src, bool* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b8(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8( + pg, + reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b8(i, n); + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8( + pg, + reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return svlsl_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..98d45ba0790f208cb165d29974d99ff1547999b1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/sve/vec_qint.h @@ -0,0 +1,611 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include +#include +#include +#include +#include +#include + +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with SVE. This may not be an issue, because +// currently for quantization we assume the user has at least SVE +// installed, so these can simply act as a reference implementation. +// +// If in the future we relax this requirement (SVE+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + using size_type = int; + static constexpr size_type size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size() / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (size_t i = 0; i < size(); ++i) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = at::native::dequantize_val( + tmp_scale[j], + tmp_zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = at::native::dequantize_val( + tmp_scale[j], + tmp_zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return svld1_s32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_s32(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (size_t i = 0; i < size(); ++i) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = + nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return svld1_u8(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b8(0ull, count); + return svld1_u8(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store( + &float_vals[i * Vectorized::size()], + Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h new file mode 100644 index 0000000000000000000000000000000000000000..c5c4fb5c289aeb2f3c54172adbc614aebf490e4c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h @@ -0,0 +1,62 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#if defined(CPU_CAPABILITY_AVX512) +#include +#else +#include +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline Vectorized convert_to_bool(Vectorized x) { + __at_align__ bool buffer[x.size()]; + x.ne(Vectorized(0)).store(buffer); + + Vectorized ret; + static_assert(x.size() == ret.size()); + std::memcpy(ret, buffer, ret.size() * sizeof(bool)); + return ret; +} + +template <> +inline Vectorized Vectorized::loadu(const void* ptr) { + // See NOTE [Loading boolean values] + return convert_to_bool(Vectorized::loadu(ptr)); +} + +template <> +inline Vectorized Vectorized::loadu( + const void* ptr, + int64_t count) { + // See NOTE [Loading boolean values] + return convert_to_bool(Vectorized::loadu(ptr, count)); +} + +template +struct VecHoldType { + using hold_type = typename VT::value_type; +}; + +template <> +struct VecHoldType> { + using hold_type = BFloat16; +}; + +template <> +struct VecHoldType> { + using hold_type = Half; +}; + +template +using vechold_type = typename VecHoldType::hold_type; + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h new file mode 100644 index 0000000000000000000000000000000000000000..766f980da7088f7f7f830bf84299de836e361837 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128.h @@ -0,0 +1,22 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// ARM NEON uses 128-bit vector registers. + +#include + +#ifdef __aarch64__ +#if !defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#include +#include +#include +#endif + +#include +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..5ae7920fa4a90b434bfba8238c96926bcc522f96 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h @@ -0,0 +1,703 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] +#include +#include +#include +#include +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Following vec128_half_neon.h, we only support aarch64. +#if !defined(C10_MOBILE) && defined(__aarch64__) +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." +#endif + +// GCC does not properly optimize bf16 operators +#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19) +#define BF16_ARITHMETIC_SUPPORTED() 1 +#else +#define BF16_ARITHMETIC_SUPPORTED() 0 +#endif + +// Unlike the float16_t family of types, bfloat16_t is not available +// when we're not targeting bfloat16 hardware support on some +// platforms (but not Mac, so we have to be careful not to shadow the +// definitions in case they are actually there!). (See +// https://godbolt.org/z/orv6e94n4 ) So, we need to handle it as +// uint16_t in that case. +#define IMPLEMENT_AT_BF16_SHIM(vec_suffix) \ + inline at_bfloat16x4_t at_vget_low_bf16(at_bfloat16x8_t a) { \ + return vget_low_##vec_suffix(a); \ + } \ + \ + inline at_bfloat16x4_t at_vget_high_bf16(at_bfloat16x8_t a) { \ + return vget_high_##vec_suffix(a); \ + } \ + \ + inline at_bfloat16x8_t at_vcombine_bf16( \ + at_bfloat16x4_t low, at_bfloat16x4_t high) { \ + return vcombine_##vec_suffix(low, high); \ + } \ + \ + inline at_bfloat16x8_t at_vdupq_n_bf16(at_bfloat16_t value) { \ + return vdupq_n_##vec_suffix(value); \ + } \ + \ + inline at_bfloat16x8_t at_vld1q_bf16(const at_bfloat16_t* ptr) { \ + return vld1q_##vec_suffix(ptr); \ + } \ + \ + inline void at_vst1q_bf16(at_bfloat16_t* ptr, at_bfloat16x8_t value) { \ + vst1q_##vec_suffix(ptr, value); \ + } \ + \ + template \ + inline at_bfloat16x8_t at_vreinterpretq_bf16_u16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpretq_bf16_u16(val); \ + } \ + } \ + template \ + inline at_bfloat16x4_t at_vreinterpret_bf16_u16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpret_bf16_u16(val); \ + } \ + } \ + template \ + inline uint16x8_t at_vreinterpretq_u16_bf16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpretq_u16_bf16(val); \ + } \ + } \ + template \ + inline uint16x4_t at_vreinterpret_u16_bf16(T val) { \ + if constexpr (std::is_same_v) { \ + return val; \ + } else { \ + return vreinterpret_u16_bf16(val); \ + } \ + } + +#ifdef __ARM_FEATURE_BF16 +using at_bfloat16x8_t = bfloat16x8_t; +using at_bfloat16x4_t = bfloat16x4_t; +using at_bfloat16_t = bfloat16_t; +IMPLEMENT_AT_BF16_SHIM(bf16) +#define at_vsetq_lane_bf16 vsetq_lane_bf16 +#define at_vgetq_lane_bf16 vgetq_lane_bf16 +#else +using at_bfloat16x8_t = uint16x8_t; +using at_bfloat16x4_t = uint16x4_t; +using at_bfloat16_t = uint16_t; +IMPLEMENT_AT_BF16_SHIM(u16) +#define at_vsetq_lane_bf16 vsetq_lane_u16 +#define at_vgetq_lane_bf16 vgetq_lane_u16 +#endif // __ARM_FEATURE_BF16 + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res); +}; + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res) { + return at_vsetq_lane_bf16(at_vgetq_lane_bf16(b, index), res, index); + } +}; + +template +struct BlendBFloat16Regs { + static at_bfloat16x8_t impl( + const at_bfloat16x8_t& a, + const at_bfloat16x8_t& b, + at_bfloat16x8_t& res) { + return at_vsetq_lane_bf16(at_vgetq_lane_bf16(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16< + at_bfloat16x8_t, + c10::BFloat16, + BlendBFloat16Regs, + Vectorized> { + using Base = Vectorized16< + at_bfloat16x8_t, + c10::BFloat16, + BlendBFloat16Regs, + Vectorized>; + friend Base; + friend std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a); + friend Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b); + + private: + Vectorized map2( + const Vectorized& second, + c10::BFloat16 (*const f)(c10::BFloat16, c10::BFloat16)) const { + __at_align__ c10::BFloat16 tmp_first[size()]; + __at_align__ c10::BFloat16 tmp_second[size()]; + store(tmp_first); // store this to tmp_first + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp_first[i] = f(tmp_first[i], tmp_second[i]); + } + return loadu(tmp_first); + } + + static float32x4_t convert_f32_bf16(at_bfloat16x4_t bf16) { +#ifdef __ARM_FEATURE_BF16 + return vcvt_f32_bf16(bf16); +#else + int32x4_t shift = vdupq_n_s32(16); + return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(bf16), shift)); +#endif // __ARM_FEATURE_BF16 + } + + static at_bfloat16x4_t convert_bf16_f32(const Vectorized& f32) { +#ifdef __ARM_FEATURE_BF16 + return vcvt_bf16_f32(f32); +#else + static_assert(std::is_same_v); + uint32x4_t as_uint32 = vreinterpretq_u32_f32(f32); + uint32x4_t rounding_bias = vaddq_u32( + vandq_u32(vshrq_n_u32(as_uint32, 16), vdupq_n_u32(1)), + vdupq_n_u32(0x7FFF)); + at_bfloat16x4_t rounded = + vshrn_n_u32(vaddq_u32(as_uint32, rounding_bias), 16); + const auto bf16_nan = vdup_n_u16(0x7FC0); + return vbsl_u16( + vmovn_u32(vreinterpretq_u32_f32(f32.isnan())), bf16_nan, rounded); +#endif // __ARM_FEATURE_BF16 + } + + Vectorized map_with_vec_float_method( + Vectorized (Vectorized::*m)() const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + Vectorized mv0 = (Vectorized(v00).*m)(); + Vectorized mv1 = (Vectorized(v01).*m)(); + at_bfloat16x4_t r00 = convert_bf16_f32(mv0); + at_bfloat16x4_t r01 = convert_bf16_f32(mv1); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + Vectorized map2_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values)); + float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values)); + Vectorized mv0 = (Vectorized(v00).*m)(second_v00); + Vectorized mv1 = (Vectorized(v01).*m)(second_v01); + at_bfloat16x4_t r00 = convert_bf16_f32(mv0); + at_bfloat16x4_t r01 = convert_bf16_f32(mv1); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + Vectorized map2_bitmask_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = convert_f32_bf16(at_vget_low_bf16(values)); + float32x4_t v01 = convert_f32_bf16(at_vget_high_bf16(values)); + float32x4_t second_v00 = convert_f32_bf16(at_vget_low_bf16(second.values)); + float32x4_t second_v01 = convert_f32_bf16(at_vget_high_bf16(second.values)); + Vectorized mv0 = (Vectorized(v00).*m)(second_v00); + Vectorized mv1 = (Vectorized(v01).*m)(second_v01); + // Assume the operator returns a bitmask, not "real" floats, and + // just narrow the bits. All-ones is a NaN and will get mangled by + // conversion! + at_bfloat16x4_t r00 = + at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0))); + at_bfloat16x4_t r01 = + at_vreinterpret_bf16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1))); + return Vectorized(at_vcombine_bf16(r00, r01)); + } + + public: + using Vectorized16::Vectorized16; + + Vectorized() = default; + + Vectorized(c10::BFloat16 val) + : Vectorized16(at_vdupq_n_bf16(c10::bit_cast(val.x))) {} + Vectorized(float val) : Vectorized(c10::BFloat16(val)) {} + Vectorized( + value_type val0, + value_type val1, + value_type val2, + value_type val3, + value_type val4, + value_type val5, + value_type val6, + value_type val7) + : Vectorized16(at_bfloat16x8_t{ + c10::bit_cast(val0.x), + c10::bit_cast(val1.x), + c10::bit_cast(val2.x), + c10::bit_cast(val3.x), + c10::bit_cast(val4.x), + c10::bit_cast(val5.x), + c10::bit_cast(val6.x), + c10::bit_cast(val7.x)}) {} + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // NOTE: blendv has the same problems as it does for Half; see comments in + // vec128_half_neon.h. + Vectorized vec(mask.values); + vec.values = at_vreinterpretq_bf16_u16(vbslq_u16( + at_vreinterpretq_u16_bf16(vec.values), + at_vreinterpretq_u16_bf16(b.values), + at_vreinterpretq_u16_bf16(a.values))); + return vec; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + uint16_t pre_mask[size()] = {0}; + for (int i = 0; i < count; i++) { + pre_mask[i] = 0xFFFF; + } + uint16x8_t mask = vld1q_u16(pre_mask); + + Vectorized vec(at_vreinterpretq_bf16_u16(vbslq_u16( + mask, + at_vreinterpretq_u16_bf16(b.values), + at_vreinterpretq_u16_bf16(a.values)))); + + return vec; + } + static Vectorized loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) { + return at_vld1q_bf16(reinterpret_cast(ptr)); + } + __at_align__ at_bfloat16_t tmp_values[size()]; + std::memset(tmp_values, 0, sizeof(tmp_values)); + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(at_bfloat16_t)); + return at_vld1q_bf16(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + at_vst1q_bf16(reinterpret_cast(ptr), values); + return; + } else { + at_bfloat16_t tmp_values[size()]; + at_vst1q_bf16(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(at_bfloat16_t)); + } + } + Vectorized isnan() const { + // NOTE: we could make this faster by doing vectorized checks of + // exponent/payload bits. + __at_align__ c10::BFloat16 tmp[size()]; + __at_align__ c10::BFloat16 res[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i])) { + std::memset(static_cast(&res[i]), 0xFF, sizeof(c10::BFloat16)); + } else { + std::memset(static_cast(&res[i]), 0, sizeof(c10::BFloat16)); + } + } + return loadu(res); + } + bool has_inf_nan() const { + __at_align__ c10::BFloat16 tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } +#define DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(name) \ + Vectorized name() const { \ + return map_with_vec_float_method(&Vectorized::name); \ + } + +#define DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(name) \ + Vectorized name(const Vectorized& other) const { \ + return map2_bitmask_with_vec_float_method( \ + other, &Vectorized::name); \ + } + + Vectorized frac() const; + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) + +#ifdef __ARM_FEATURE_BF16 + // Flip sign bit + Vectorized neg() const { + return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768)); + } + // Fast reciprocal is fine because we are truncating results + Vectorized reciprocal() const { + auto x = vcvtq_low_f32_bf16(values); + auto y = vcvtq_high_f32_bf16(values); + x = vrecpeq_f32(x); + y = vrecpeq_f32(y); + return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y); + } + // Clearing the sign bit + Vectorized abs() const { + return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF); + } +#else + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) +#endif + +// These functions are optimized on clang-21+ +#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21) + Vectorized operator==( + const Vectorized& other) const { + return values == other.values; + } + + Vectorized operator!=( + const Vectorized& other) const { + return values != other.values; + } + + Vectorized operator<( + const Vectorized& other) const { + return values < other.values; + } + + Vectorized operator<=( + const Vectorized& other) const { + return values <= other.values; + } + + Vectorized operator>( + const Vectorized& other) const { + return values > other.values; + } + + Vectorized operator>=( + const Vectorized& other) const { + return values >= other.values; + } +#else + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) + DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) +#endif + +#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD +#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; // Vectorized + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + at_bfloat16x8_t x = a; + float32x4_t x1 = + Vectorized::convert_f32_bf16(at_vget_low_bf16(x)); + float32x4_t x2 = + Vectorized::convert_f32_bf16(at_vget_high_bf16(x)); + return {Vectorized(x1), Vectorized(x2)}; +} +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + static_assert( + Vectorized::size() == 2 * Vectorized::size()); + at_bfloat16x4_t x1 = Vectorized::convert_bf16_f32(a); + at_bfloat16x4_t x2 = Vectorized::convert_bf16_f32(b); + return Vectorized(at_vcombine_bf16(x1, x2)); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_bfloat16_float(a); + const auto [b_float_low, b_float_high] = convert_bfloat16_float(b); + return convert_float_bfloat16( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + return x + y; +#else + return binary_operator_via_float(std::plus>(), a, b); +#endif +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + return x - y; +#else + return binary_operator_via_float(std::minus>(), a, b); +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + return x * y; +#else + return binary_operator_via_float(std::multiplies>(), a, b); +#endif +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + return x / y; +#else + return binary_operator_via_float(std::divides>(), a, b); +#endif +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + vandq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + vorrq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(at_vreinterpretq_bf16_u16( + veorq_u16(at_vreinterpretq_u16_bf16(a), at_vreinterpretq_u16_bf16(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + bfloat16x8_t z = c; + return x * y + z; +#else + // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also, + // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered + // elements, not the bottom and top half, so they don't seem + // particularly useful here. Ideally we would include dot product in + // the Vectorized interface... + return a * b + c; +#endif +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + bfloat16x8_t z = c; + return (-x) * y + z; +#else + // See NOTE [BF16 FMA] above. + return -a * b + c; +#endif +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + bfloat16x8_t z = c; + return x * y - z; +#else + // See NOTE [BF16 FMA] above. + return a * b - c; +#endif +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#if BF16_ARITHMETIC_SUPPORTED() + bfloat16x8_t x = a; + bfloat16x8_t y = b; + bfloat16x8_t z = c; + return (-x) * y - z; +#else + // See NOTE [BF16 FMA] above. + return -a * b - c; +#endif +} + +#endif // !defined(C10_MOBILE) && defined(__aarch64__) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..da9fb21eb24e3e9ad179fea82ad1ce6d242bc1a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_convert.h @@ -0,0 +1,383 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) + +// Enable auto-vectorization for clang-17+ +// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 +#if defined(__clang__) && (__clang_major__ >= 17) + +template +inline void convertImpl( + const from_type* __restrict src, + to_type* __restrict dst, + int64_t n) { + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + dst[i] = static_cast(src[i]); + } +} + +template +inline void convertFromBool( + const bool* __restrict src, + to_type* __restrict dst, + int64_t n) { + const uint8_t* srcPtr = reinterpret_cast(src); + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + dst[i] = srcPtr[i] != 0 ? static_cast(1) : static_cast(0); + } +} + +template +inline void convertToBool( + const from_type* __restrict src, + bool* __restrict dst, + int64_t n) { + uint8_t* dstPtr = reinterpret_cast(dst); + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + dstPtr[i] = src[i] != static_cast(0) ? 1 : 0; + } +} + +#define CONVERT_TEMPLATE(from_type, to_type) \ + template <> \ + inline void convert(const from_type* src, to_type* dst, int64_t n) { \ + return convertImpl(src, dst, n); \ + } + +#define CONVERT_FROM_BOOL_TEMPLATE(to_type) \ + inline void convert(const bool* src, to_type* dst, int64_t n) { \ + return convertFromBool(src, dst, n); \ + } + +#define CONVERT_TO_BOOL_TEMPLATE(from_type) \ + inline void convert(const from_type* src, bool* dst, int64_t n) { \ + return convertToBool(src, dst, n); \ + } + +CONVERT_TEMPLATE(uint8_t, uint8_t) +CONVERT_TEMPLATE(uint8_t, int8_t) +CONVERT_TEMPLATE(uint8_t, int16_t) +CONVERT_TEMPLATE(uint8_t, int32_t) +CONVERT_TEMPLATE(uint8_t, int64_t) +CONVERT_TEMPLATE(uint8_t, float) +CONVERT_TEMPLATE(uint8_t, double) +CONVERT_TO_BOOL_TEMPLATE(uint8_t) +CONVERT_TEMPLATE(int8_t, uint8_t) +CONVERT_TEMPLATE(int8_t, int8_t) +CONVERT_TEMPLATE(int8_t, int16_t) +CONVERT_TEMPLATE(int8_t, int32_t) +CONVERT_TEMPLATE(int8_t, int64_t) +CONVERT_TEMPLATE(int8_t, float) +CONVERT_TEMPLATE(int8_t, double) +CONVERT_TO_BOOL_TEMPLATE(int8_t) +CONVERT_TEMPLATE(int16_t, uint8_t) +CONVERT_TEMPLATE(int16_t, int8_t) +CONVERT_TEMPLATE(int16_t, int16_t) +CONVERT_TEMPLATE(int16_t, int32_t) +CONVERT_TEMPLATE(int16_t, int64_t) +CONVERT_TEMPLATE(int16_t, float) +CONVERT_TEMPLATE(int16_t, double) +CONVERT_TO_BOOL_TEMPLATE(int16_t) +CONVERT_TEMPLATE(int32_t, uint8_t) +CONVERT_TEMPLATE(int32_t, int8_t) +CONVERT_TEMPLATE(int32_t, int16_t) +CONVERT_TEMPLATE(int32_t, int32_t) +CONVERT_TEMPLATE(int32_t, int64_t) +CONVERT_TEMPLATE(int32_t, float) +CONVERT_TEMPLATE(int32_t, double) +CONVERT_TO_BOOL_TEMPLATE(int32_t) +CONVERT_TEMPLATE(int64_t, uint8_t) +CONVERT_TEMPLATE(int64_t, int8_t) +CONVERT_TEMPLATE(int64_t, int16_t) +CONVERT_TEMPLATE(int64_t, int32_t) +CONVERT_TEMPLATE(int64_t, int64_t) +CONVERT_TEMPLATE(int64_t, float) +CONVERT_TEMPLATE(int64_t, double) +CONVERT_TO_BOOL_TEMPLATE(int64_t) +CONVERT_TEMPLATE(float, uint8_t) +CONVERT_TEMPLATE(float, int8_t) +CONVERT_TEMPLATE(float, int16_t) +CONVERT_TEMPLATE(float, int32_t) +CONVERT_TEMPLATE(float, int64_t) +CONVERT_TEMPLATE(float, float) +CONVERT_TEMPLATE(float, double) +CONVERT_TO_BOOL_TEMPLATE(float) +CONVERT_TEMPLATE(double, uint8_t) +CONVERT_TEMPLATE(double, int8_t) +CONVERT_TEMPLATE(double, int16_t) +CONVERT_TEMPLATE(double, int32_t) +CONVERT_TEMPLATE(double, int64_t) +CONVERT_TEMPLATE(double, float) +CONVERT_TEMPLATE(double, double) +CONVERT_TO_BOOL_TEMPLATE(double) +CONVERT_FROM_BOOL_TEMPLATE(uint8_t) +CONVERT_FROM_BOOL_TEMPLATE(int8_t) +CONVERT_FROM_BOOL_TEMPLATE(int16_t) +CONVERT_FROM_BOOL_TEMPLATE(int32_t) +CONVERT_FROM_BOOL_TEMPLATE(int64_t) +CONVERT_FROM_BOOL_TEMPLATE(float) +CONVERT_FROM_BOOL_TEMPLATE(double) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#define CONVERT_FROM_FP16_TEMPLATE(to_type) \ + template <> \ + inline void convert(const at::Half* src, to_type* dst, int64_t n) { \ + const float16_t* srcPtr = reinterpret_cast(src); \ + return convertImpl(srcPtr, dst, n); \ + } + +#define CONVERT_TO_FP16_TEMPLATE(from_type) \ + template <> \ + inline void convert(const from_type* src, at::Half* dst, int64_t n) { \ + float16_t* dstPtr = reinterpret_cast(dst); \ + return convertImpl(src, dstPtr, n); \ + } + +CONVERT_FROM_FP16_TEMPLATE(uint8_t) +CONVERT_FROM_FP16_TEMPLATE(int8_t) +CONVERT_FROM_FP16_TEMPLATE(int16_t) +CONVERT_FROM_FP16_TEMPLATE(int32_t) +CONVERT_FROM_FP16_TEMPLATE(int64_t) +CONVERT_FROM_FP16_TEMPLATE(float16_t) +CONVERT_FROM_FP16_TEMPLATE(float) +CONVERT_FROM_FP16_TEMPLATE(double) +CONVERT_TO_FP16_TEMPLATE(uint8_t) +CONVERT_TO_FP16_TEMPLATE(int8_t) +CONVERT_TO_FP16_TEMPLATE(int16_t) +CONVERT_TO_FP16_TEMPLATE(int32_t) +CONVERT_TO_FP16_TEMPLATE(int64_t) +CONVERT_TO_FP16_TEMPLATE(float) +CONVERT_TO_FP16_TEMPLATE(double) + +inline void convertBoolToFp16Impl( + const bool* __restrict src, + at::Half* __restrict dst, + int64_t n) { + const uint8_t* srcPtr = reinterpret_cast(src); + float16_t* dstPtr = reinterpret_cast(dst); + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + dstPtr[i] = srcPtr[i] != 0 ? 1.0 : 0; + } +} + +template <> +inline void convert(const bool* src, at::Half* dst, int64_t n) { + return convertBoolToFp16Impl(src, dst, n); +} + +inline void convertFp16ToBoolImpl( + const at::Half* __restrict src, + bool* __restrict dst, + int64_t n) { + const float16_t* srcPtr = reinterpret_cast(src); + uint8_t* dstPtr = reinterpret_cast(dst); + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + dstPtr[i] = srcPtr[i] != 0.0 ? 1 : 0; + } +} + +template <> +inline void convert(const at::Half* src, bool* dst, int64_t n) { + return convertFp16ToBoolImpl(src, dst, n); +} + +#endif + +template +inline void convertFromBf16Impl( + const c10::BFloat16* __restrict src, + to_type* __restrict dst, + int64_t n) { + const uint16_t* srcPtr = reinterpret_cast(src); + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + uint32_t tmp = static_cast(srcPtr[i]) << 16; + float tmpF; + __builtin_memcpy(&tmpF, &tmp, sizeof(float)); + dst[i] = static_cast(tmpF); + } +} +#define CONVERT_FROM_BF16_TEMPLATE(to_type) \ + template <> \ + inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \ + return convertFromBf16Impl(src, dst, n); \ + } + +CONVERT_FROM_BF16_TEMPLATE(uint8_t) +CONVERT_FROM_BF16_TEMPLATE(int8_t) +CONVERT_FROM_BF16_TEMPLATE(int16_t) +CONVERT_FROM_BF16_TEMPLATE(int32_t) +CONVERT_FROM_BF16_TEMPLATE(int64_t) +CONVERT_FROM_BF16_TEMPLATE(float) +CONVERT_FROM_BF16_TEMPLATE(double) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +CONVERT_FROM_BF16_TEMPLATE(float16_t) +#endif + +#ifdef __ARM_FEATURE_BF16 + +// clang-[17, 20] crashes when autovectorizing static cast to bf16 +// Below is a workaround to have some vectorization +// Works decently well for smaller int types +template +inline void convertToBf16Impl( + const from_type* __restrict src, + c10::BFloat16* __restrict dst, + uint64_t n) { + bfloat16_t* dstPtr = reinterpret_cast(dst); + uint64_t loopBound = n - (n % 16); + uint64_t i = 0; + for (; i < loopBound; i += 16) { + float32x4_t a, b, c, d; + a[0] = static_cast(src[i]); + a[1] = static_cast(src[i + 1]); + a[2] = static_cast(src[i + 2]); + a[3] = static_cast(src[i + 3]); + b[0] = static_cast(src[i + 4]); + b[1] = static_cast(src[i + 5]); + b[2] = static_cast(src[i + 6]); + b[3] = static_cast(src[i + 7]); + c[0] = static_cast(src[i + 8]); + c[1] = static_cast(src[i + 9]); + c[2] = static_cast(src[i + 10]); + c[3] = static_cast(src[i + 11]); + d[0] = static_cast(src[i + 12]); + d[1] = static_cast(src[i + 13]); + d[2] = static_cast(src[i + 14]); + d[3] = static_cast(src[i + 15]); + + vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b)); + vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d)); + } + +#pragma clang loop vectorize(disable) interleave(disable) unroll(disable) + for (; i < n; i++) { + float a = static_cast(src[i]); + dstPtr[i] = vcvth_bf16_f32(a); + } +} + +#define CONVERT_TO_BF16_TEMPLATE(from_type) \ + template <> \ + inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \ + return convertToBf16Impl(src, dst, n); \ + } + +CONVERT_TO_BF16_TEMPLATE(uint8_t) +CONVERT_TO_BF16_TEMPLATE(int8_t) +CONVERT_TO_BF16_TEMPLATE(int16_t) +CONVERT_TO_BF16_TEMPLATE(int32_t) + +#endif + +inline void convertBoolToBfloat16Impl( + const bool* __restrict src, + c10::BFloat16* __restrict dst, + int64_t n) { + const uint8_t* srcPtr = reinterpret_cast(src); + uint16_t* dstPtr = reinterpret_cast(dst); + uint64_t len = static_cast(n); + constexpr uint16_t kBf16One = 0x3f80; // 1.0 in bfloat16 + for (uint64_t i = 0; i < len; i++) { + dstPtr[i] = srcPtr[i] != 0 ? kBf16One : 0; + } +} + +template <> +inline void convert(const bool* src, c10::BFloat16* dst, int64_t n) { + return convertBoolToBfloat16Impl(src, dst, n); +} + +inline void convertBfloat16ToBoolImpl( + const c10::BFloat16* __restrict src, + bool* __restrict dst, + int64_t n) { + uint8_t* dstPtr = reinterpret_cast(dst); + const uint16_t* srcPtr = reinterpret_cast(src); + uint64_t len = static_cast(n); + for (uint64_t i = 0; i < len; i++) { + // Check if all non-sign bits are 0 + bool isBf16Zero = (srcPtr[i] & 0x7fff) == 0; + dstPtr[i] = isBf16Zero ? 0 : 1; + } +} + +template <> +inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) { + return convertBfloat16ToBoolImpl(src, dst, n); +} + +#endif + +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_half_register_to_float(src[0]); + } +}; +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + const auto [v0, v1] = convert_int8_to_float(src[0]); + return VectorizedN(v0, v1); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + uint16x8_t u16_8 = vld1q_u16(reinterpret_cast(&src[0])); + auto u16_low1 = vget_low_u16(u16_8); + auto u16_high1 = vget_high_u16(u16_8); + float32x4_t f32x4_0 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_low1), 16)); + float32x4_t f32x4_1 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_high1), 16)); + result[0] = f32x4_0; + result[1] = f32x4_1; + return result; + } +}; +// Half register to full register. +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + uint16x4_t u16_8 = vld1_u16(reinterpret_cast(&src[0])); + float32x4_t f32x4_0 = + vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_8), 16)); + result[0] = f32x4_0; + return result; + } +}; + +#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_double_neon.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_double_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..f27f9b272224af260be8b9d25ce1b0f2d2f7be90 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_double_neon.h @@ -0,0 +1,591 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + float64x2_t values; + + public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 2; + } + Vectorized() { + values = vdupq_n_f64(0.0); + } + Vectorized(float64x2_t v) : values(v) {} + Vectorized(double val) { + values = vdupq_n_f64(val); + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) { + __at_align__ double buffer[size()] = {vals...}; + values = vld1q_f64(buffer); + } + operator float64x2_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each bit of element is 1 if the corresponding + // bit in 'mask' is set, 0 otherwise. + uint64x2_t maskArray = { + (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0, + (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_f64(maskArray, b.values, a.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask_) { + return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return {base, base + static_cast(step)}; + } + static inline Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count >= 2) { + return b; + } else { + float64x2_t c = {b.values[0], a.values[1]}; + return c; + } + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f64(reinterpret_cast(ptr)); + } else if (count == 1) { + float64x1_t x = vld1_f64(reinterpret_cast(ptr)); + float64x1_t z = {0.0}; + return vcombine_f64(x, z); + } else { + return vdupq_n_f64(0.0); + } + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f64(reinterpret_cast(ptr), values); + } else if (count == 1) { + vst1_f64(reinterpret_cast(ptr), vget_low_f64(values)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + uint64x2_t cmpReg = vceqzq_f64(values); + uint64x2_t mask = {1, 2}; + uint64x2_t res = vandq_u64(cmpReg, mask); + return res[0] | res[1]; + } + Vectorized isnan() const { + // NaN check + return vreinterpretq_f64_u32( + vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values)))); + } + bool has_inf_nan() const { + Vectorized x = vsubq_f64(values, values); + float64x2_t r = x.isnan(); + uint64x2_t u = vreinterpretq_u64_f64(r); + return u[0] | u[1]; + } + Vectorized map(double (*f)(double)) const { + float64x2_t result; + result[0] = f(values[0]); + result[1] = f(values[1]); + return result; + } + Vectorized map2( + const Vectorized& second, + double (*const f)(double, double)) const { + float64x2_t result; + result[0] = f(values[0], second.values[0]); + result[1] = f(values[1], second.values[1]); + return result; + } + Vectorized abs() const { + return vabsq_f64(values); + } + Vectorized angle() const { + auto zero = Vectorized(0.0); + auto pi = Vectorized(c10::pi); + auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values))); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return USE_SLEEF( + Vectorized(Sleef_acosd2_u10(values)), map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( + Vectorized(Sleef_acoshd2_u10(values)), map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF( + Vectorized(Sleef_asind2_u10(values)), map(std::asin)); + } + Vectorized asinh() const { + return USE_SLEEF( + Vectorized(Sleef_asinhd2_u10(values)), map(std::asinh)); + } + Vectorized atan() const { + return USE_SLEEF( + Vectorized(Sleef_atand2_u10(values)), map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF( + Vectorized(Sleef_atanhd2_u10(values)), map(std::atanh)); + } + Vectorized atan2(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_atan2d2_u10(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized copysign(const Vectorized& sign) const { + USE_SLEEF( + { return Vectorized(Sleef_copysignd2(values, sign)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + })} Vectorized erf() const { + return USE_SLEEF( + Vectorized(Sleef_erfd2_u10(values)), map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF( + Vectorized(Sleef_erfcd2_u15(values)), map(std::erfc)); + } + Vectorized exp() const { + return USE_SLEEF( + Vectorized(Sleef_expd2_u10(values)), map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF( + Vectorized(Sleef_exp2d2_u10(values)), map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF( + Vectorized(Sleef_expm1d2_u10(values)), map(std::expm1)); + } + Vectorized fmod(const Vectorized& q) const {USE_SLEEF( + { return Vectorized(Sleef_fmodd2(values, q)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + })} Vectorized hypot(const Vectorized& b) const { + USE_SLEEF( + { return Vectorized(Sleef_hypotd2_u05(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized i0() const { + return map(calc_i0); + } + Vectorized nextafter(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_nextafterd2(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} Vectorized log() const { + return USE_SLEEF( + Vectorized(Sleef_logd2_u10(values)), map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF( + Vectorized(Sleef_log2d2_u10(values)), map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF( + Vectorized(Sleef_log10d2_u10(values)), map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF( + Vectorized(Sleef_log1pd2_u10(values)), map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( + Vectorized(Sleef_sind2_u10(values)), map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF( + Vectorized(Sleef_sinhd2_u10(values)), map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF( + Vectorized(Sleef_cosd2_u10(values)), map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( + Vectorized(Sleef_coshd2_u10(values)), map(std::cosh)); + } + Vectorized pow(const Vectorized& b) const {USE_SLEEF( + { return Vectorized(Sleef_powd2_u10(values, b)); }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + })} // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized tan() const { + return USE_SLEEF( + Vectorized(Sleef_tand2_u10(values)), map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( + Vectorized(Sleef_tanhd2_u10(values)), map(std::tanh)); + } + Vectorized lgamma() const { + return USE_SLEEF( + Vectorized(Sleef_lgammad2_u10(values)), map(std::lgamma)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized ceil() const { + return vrndpq_f64(values); + } + Vectorized floor() const { + return vrndmq_f64(values); + } + Vectorized neg() const { + return vnegq_f64(values); + } + Vectorized round() const { + return vrndiq_f64(values); + } + Vectorized trunc() const { + return vrndq_f64(values); + } + Vectorized sqrt() const { + return vsqrtq_f64(values); + } + Vectorized reciprocal() const { + return vdivq_f64(vdupq_n_f64(1.0), values); + } + Vectorized rsqrt() const { + return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values)); + } + double reduce_add() const { + return vaddvq_f64(values); + } + double reduce_max() const { + return vmaxvq_f64(values); + } + Vectorized operator==(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f64_u64(vceqq_f64(values, other.values))); + } + + Vectorized operator!=(const Vectorized& other) const { + float64x2_t r0 = vreinterpretq_f64_u32( + vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values)))); + return Vectorized(r0); + } + + Vectorized operator<(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f64_u64(vcltq_f64(values, other.values))); + } + + Vectorized operator<=(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f64_u64(vcleq_f64(values, other.values))); + } + + Vectorized operator>(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f64_u64(vcgtq_f64(values, other.values))); + } + + Vectorized operator>=(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f64_u64(vcgeq_f64(values, other.values))); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return vaddq_f64(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return vsubq_f64(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return vmulq_f64(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return vdivq_f64(a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return vmaxq_f64(a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return vminq_f64(a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return vminq_f64(max, vmaxq_f64(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return vminq_f64(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return vmaxq_f64(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return vreinterpretq_f64_u64( + vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return vreinterpretq_f64_u64( + vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return vreinterpretq_f64_u64( + veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return vfmaq_f64(c, a, b); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return vfmsq_f64(c, a, b); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return vfmaq_f64(vnegq_f64(c), a, b); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return vfmsq_f64(vnegq_f64(c), a, b); +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..c6f047f86fc4f62fc82e24506f688e7d39a92214 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -0,0 +1,661 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#endif + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default") + +// Sleef offers vectorized versions of some transcedentals +// such as sin, cos, tan etc.. +// However for now opting for STL, since we are not building +// with Sleef for mobile yet. + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Right now contains only aarch64 implementation. +// Due to follow two reasons aarch32 is not currently supported. +// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics +// that work for aarch64 dont work for aarch32. +// 2. Android NDK r21 has problems with compiling aarch32. +// Clang seg faults. +// https://github.com/android/ndk/issues/1248 +// https://bugs.llvm.org/show_bug.cgi?id=45824 +// Most likely we will do aarch32 support with inline asm. +#if defined(__aarch64__) + +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." +#endif + +#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res); +}; + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index); + } +}; + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, + const float32x4_t& b, + float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + float32x4_t values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() { + values = vmovq_n_f32(0); + } + Vectorized(float32x4_t v) : values(v) {} + Vectorized(float val) : values{vdupq_n_f32(val)} {} + Vectorized(float val0, float val1, float val2, float val3) + : values{val0, val1, val2, val3} {} + Vectorized(float (&arr)[4]) : Vectorized(arr[0], arr[1], arr[2], arr[3]) {} + operator float32x4_t() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + Vectorized vec; + vec.values = BlendRegs < 0, + (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 1, + (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 2, + (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 3, + (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values); + return vec; + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // TODO + // NB: This requires that each value, i.e., each uint value, + // of the mask either all be zeros or all be 1s. + // We perhaps need some kind of an assert? + // But that will affect performance. + Vectorized vec(mask.values); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const Vectorized step_sizes(0, 1, 2, 3); + return fmadd(step_sizes, step_vec, base_vec); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + case 2: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + case 3: { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = + vbslq_f32(vreinterpretq_u32_f32(vec.values), b.values, a.values); + return vec; + } + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f32(reinterpret_cast(ptr)); + } else { + __at_align__ float tmp_values[size()]; + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(float)); + return vld1q_f32(reinterpret_cast(tmp_values)); + } + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f32(reinterpret_cast(ptr), values); + } else { + float tmp_values[size()]; + vst1q_f32(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + // Very slow implementation of indexing. + // Only required because vec256_qint refers to this. + // Once we specialize that implementation for ARM + // this should be removed. TODO (kimishpatel) + float operator[](int idx) const { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + float operator[](int idx) { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + int zero_mask() const { + uint32x4_t is_zero_vec = vceqzq_f32(values); + const int32x4_t shift = vcombine_s32( + vcreate_s32(0x0 | (int64_t(0x1) << 32)), + vcreate_s32(0x2 | (int64_t(0x3) << 32))); + uint32x4_t bits_vec = + vshlq_u32(vandq_u32(is_zero_vec, vdupq_n_u32(1)), shift); + return vaddvq_u32(bits_vec); + } + Vectorized isnan() const { + return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values))); + } + bool has_inf_nan() const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized map2( + const Vectorized& second, + float (*const f)(float, float)) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_second[size()]; + store(tmp); + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i], tmp_second[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return Vectorized(vabsq_f32(values)); + } + Vectorized angle() const { + auto zero = Vectorized(0); + auto pi = Vectorized(c10::pi); + auto tmp = blendv(zero, pi, *this < zero); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return *this; + } +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, sleef_name) \ + Vectorized name() const { \ + return USE_SLEEF(Vectorized(sleef_name(values)), map(std::name)); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acosh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asin) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asinh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atanh) + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, sleef_name) \ + Vectorized name(const Vectorized& arg) const { \ + return USE_SLEEF( \ + Vectorized(sleef_name(values, arg.values)), \ + map2(arg, std::name)); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( \ + name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(atan2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + copysign, + Sleef_copysignf4) + Vectorized erf() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + erfc, + Sleef_erfcf4_u15) + Vectorized erfinv() const { + return map(calc_erfinv); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) + // Implementation copied from Arm Optimized Routine + // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c + inline Vectorized vexpq_f32_u20() const { + // bail out to sleef if it's a special case: + // i.e. there's an input s.t. |input| > 87.3.... + const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); + uint32x4_t cmp = vcagtq_f32(values, special_bound); + if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) { + return exp(); + } + + const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); + const float ln2_hi = 0x1.62e4p-1f; + const float ln2_lo = 0x1.7f7d1cp-20f; + const float c0 = 0x1.0e4020p-7f; + const float c2 = 0x1.555e66p-3f; + const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; + + const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); + const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); + const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); + const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); + + /* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)] + x = ln2*n + r, with r in [-ln2/2, ln2/2]. */ + + float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); + float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); + r = vfmsq_laneq_f32(r, n, ln2_c02, 1); + uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); + float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); + + float32x4_t r2 = vmulq_f32(r, r); + float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); + float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); + q = vfmaq_f32(q, p, r2); + p = vmulq_f32(c4, r); + float32x4_t poly = vfmaq_f32(p, q, r2); + + return vfmaq_f32(scale, poly, scale); + } + Vectorized exp_u20() const { + return vexpq_f32_u20(); + } + Vectorized fexp_u20() const { + return exp_u20(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + fmod, + Sleef_fmodf4) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + hypot, + Sleef_hypotf4_u05) + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + return map2(x, calc_igamma); + } + Vectorized igammac(const Vectorized& x) const { + return map2(x, calc_igammac); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log10) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log1p) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( + nextafter, + Sleef_nextafterf4) + Vectorized frac() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sin) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sinh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cosh) + Vectorized ceil() const { + return map(at::native::ceil_impl); + } + Vectorized floor() const { + return map(at::native::floor_impl); + } + Vectorized neg() const { + return Vectorized(vnegq_f32(values)); + } + Vectorized round() const { + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. + return map(at::native::round_impl); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tanh) + Vectorized trunc() const { + return Vectorized(vrndq_f32(values)); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(lgamma) + Vectorized sqrt() const { + return Vectorized(vsqrtq_f32(values)); + } + Vectorized reciprocal() const { + return Vectorized(vdivq_f32(vdupq_n_f32(1.0f), values)); + } + Vectorized rsqrt() const { + return this->sqrt().reciprocal(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(pow) + Vectorized operator==(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vceqq_f32(values, other.values))); + } + + Vectorized operator!=(const Vectorized& other) const { + float32x4_t r0 = + vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, other.values))); + return Vectorized(r0); + } + + Vectorized operator<(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcltq_f32(values, other.values))); + } + + Vectorized operator<=(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcleq_f32(values, other.values))); + } + + Vectorized operator>(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcgtq_f32(values, other.values))); + } + + Vectorized operator>=(const Vectorized& other) const { + return Vectorized( + vreinterpretq_f32_u32(vcgeq_f32(values, other.values))); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vaddq_f32(a, b)); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vsubq_f32(a, b)); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vmulq_f32(a, b)); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vdivq_f32(a, b)); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vmaxq_f32(a, b)); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vminq_f32(a, b)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + vandq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + vorrq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32( + veorq_u32(vreinterpretq_u32_f32(a), vreinterpretq_u32_f32(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vfmaq_f32(c, a, b)); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vfmsq_f32(c, a, b)); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vnegq_f32(vfmsq_f32(c, a, b))); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized(vnegq_f32(vfmaq_f32(c, a, b))); +} + +inline Vectorized Vectorized::erf() const { + // constants + const Vectorized neg_zero_vec(-0.f); + const Vectorized one_vec(1.0f); + const Vectorized p(0.3275911f); + const Vectorized p1(0.254829592f); + const Vectorized p2(-0.284496736f); + const Vectorized p3(1.421413741f); + const Vectorized p4(-1.453152027f); + const Vectorized p5(1.061405429f); + // sign(x) + auto sign_mask = neg_zero_vec & *this; + auto abs_vec = this->abs(); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = fmadd(p, abs_vec, one_vec); + auto t = one_vec / tmp0; + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = fmadd(p5, t, p4); + auto tmp2 = fmadd(tmp1, t, p3); + auto tmp3 = fmadd(tmp2, t, p2); + auto r = fmadd(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = (*this) * (*this); + auto neg_pow_2 = pow_2 ^ neg_zero_vec; + auto tmp4 = neg_pow_2.vexpq_f32_u20(); + auto tmp5 = tmp4 ^ neg_zero_vec; + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = t * tmp5; + auto tmp7 = fmadd(tmp6, r, one_vec); + return tmp7 ^ sign_mask; +} +#undef DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC +#undef DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC +#endif /* defined(aarch64) */ + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..ad49d388341c6e8f470bff7fde35ea404e0b83de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -0,0 +1,627 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#include +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Right now contains only aarch64 implementation. +// Due to follow two reasons aarch32 is not currently supported. +// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics +// that work for aarch64 dont work for aarch32. +// 2. Android NDK r21 has problems with compiling aarch32. +// Clang seg faults. +// https://github.com/android/ndk/issues/1248 +// https://bugs.llvm.org/show_bug.cgi?id=45824 +// Most likely we will do aarch32 support with inline asm. +#if !defined(C10_MOBILE) && defined(__aarch64__) + +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." +#endif + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res); +}; + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res) { + return vsetq_lane_f16(vgetq_lane_f16(b, index), res, index); + } +}; + +template +struct BlendHalfRegs { + static float16x8_t impl( + const float16x8_t& a, + const float16x8_t& b, + float16x8_t& res) { + return vsetq_lane_f16(vgetq_lane_f16(a, index), res, index); + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +// On ARM, Half type supports float16_t->Half constructor and Half->float16_t +// conversion +template <> +class Vectorized : public Vectorized16< + float16x8_t, + c10::Half, + BlendHalfRegs, + Vectorized> { + using Base = Vectorized16< + float16x8_t, + c10::Half, + BlendHalfRegs, + Vectorized>; + friend Base; + + private: + // We use these private map functions to implement various methods + Vectorized map_with_vec_float_method( + Vectorized (Vectorized::*m)() const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + Vectorized mv0 = (Vectorized(v00).*m)(); + Vectorized mv1 = (Vectorized(v01).*m)(); + float16x4_t r00 = vcvt_f16_f32(mv0); + float16x4_t r01 = vcvt_f16_f32(mv1); + return Vectorized(vcombine_f16(r00, r01)); + } + + Vectorized map2_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values)); + float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values)); + Vectorized mv0 = + (Vectorized(v00).*m)(Vectorized(second_v00)); + Vectorized mv1 = + (Vectorized(v01).*m)(Vectorized(second_v01)); + float16x4_t r00 = vcvt_f16_f32(mv0); + float16x4_t r01 = vcvt_f16_f32(mv1); + + // Pack result into Vectorized + return Vectorized(vcombine_f16(r00, r01)); + } + + Vectorized map2_bitmask_with_vec_float_method( + const Vectorized& second, + Vectorized (Vectorized::*m)(const Vectorized&) + const) const { + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values)); + float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values)); + Vectorized mv0 = + (Vectorized(v00).*m)(Vectorized(second_v00)); + Vectorized mv1 = + (Vectorized(v01).*m)(Vectorized(second_v01)); + // Assume the operator returns a bitmask, not "real" floats, and + // just narrow the bits. All-ones is a NaN and will get mangled by + // conversion! + float16x4_t r00 = + vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0))); + float16x4_t r01 = + vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1))); + + // Pack result into Vectorized + return Vectorized(vcombine_f16(r00, r01)); + } + + public: + using Vectorized16::Vectorized16; + + Vectorized() = default; + + // A ctor that accepts c10::Half is needed to fit interface with vec_base.h + // A second constructor that takes float16_t is also included + Vectorized(c10::Half val) : Vectorized((float16_t)val) {} + Vectorized(float16_t val) : Vectorized16(vdupq_n_f16(val)) {} + Vectorized( + value_type val0, + value_type val1, + value_type val2, + value_type val3, + value_type val4, + value_type val5, + value_type val6, + value_type val7) + : Vectorized16( + float16x8_t{val0, val1, val2, val3, val4, val5, val6, val7}) {} + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // Note: using blendv is very awkward because 0xFFFF is one of + // many NaN's in FP16 It's unfortunate that the mask has type Half + // (required from vec_base) + + // TODO + // NB: This requires that each value, i.e., each uint value, + // of the mask either all be zeros or all be 1s. + // We perhaps need some kind of an assert? + // But that will affect performance. + + // NOTE [vbslq_f16]: vbslq_f16 doesn't work on clang without + // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. vbslq_u16 generates the + // same instruction anyway. see https://godbolt.org/z/cY4a55Y7P + Vectorized vec(mask.values); + vec.values = vreinterpretq_f16_u16(vbslq_u16( + vreinterpretq_u16_f16(vec.values), + vreinterpretq_u16_f16(b.values), + vreinterpretq_u16_f16(a.values))); + return vec; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + uint16_t pre_mask[size()] = {0}; + for (int i = 0; i < count; i++) { + pre_mask[i] = 0xFFFF; + } + uint16x8_t mask = vld1q_u16(pre_mask); + + // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16 + // so we directly use vbslq_u16 instead. (See NOTE [vbslq_f16] above.) + Vectorized vec(vreinterpretq_f16_u16(vbslq_u16( + mask, + vreinterpretq_u16_f16(b.values), + vreinterpretq_u16_f16(a.values)))); + + return vec; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f16(reinterpret_cast(ptr)); + } + __at_align__ float16_t tmp_values[size()]; + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(float16_t)); + return vld1q_f16(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f16(reinterpret_cast(ptr), values); + return; + } else { + float16_t tmp_values[size()]; + vst1q_f16(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float16_t)); + } + } + int zero_mask() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + uint16x8_t is_zero_vec = vceqzq_f16(values); + const int16x8_t shift = vcombine_s16( + vcreate_s16( + 0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) | + (int64_t(0x3) << 48)), + vcreate_s16( + 0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) | + (int64_t(0x7) << 48))); + uint16x8_t bits_vec = + vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift); + return vaddvq_u16(bits_vec); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + // use known working implementation. + __at_align__ value_type tmp[size()]; + store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0) { + mask |= (1 << i); + } + } + return mask; +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } + Vectorized isnan() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values))); +#else + // NOTE: we could make this faster by doing vectorized checks of + // exponent/payload bits. + __at_align__ c10::Half tmp[size()]; + __at_align__ c10::Half res[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i])) { + std::memset(static_cast(&res[i]), 0xFF, sizeof(c10::Half)); + } else { + std::memset(static_cast(&res[i]), 0, sizeof(c10::Half)); + } + } + return loadu(res); +#endif + } + bool has_inf_nan() const { + __at_align__ c10::Half tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } + Vectorized abs() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vabsq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::abs); +#endif + } + Vectorized frac() const; + Vectorized neg() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vnegq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::neg); +#endif + } + Vectorized trunc() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vrndq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::trunc); +#endif + } + Vectorized sqrt() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vsqrtq_f16(values)); +#else + return map_with_vec_float_method(&Vectorized::sqrt); +#endif + } + Vectorized reciprocal() const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + auto ones = vdupq_n_f16(1.0f); + return Vectorized(vdivq_f16(ones, values)); +#else + return map_with_vec_float_method(&Vectorized::reciprocal); +#endif + } + Vectorized operator==(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vceqq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator==); +#endif + } + + Vectorized operator!=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, other.values)))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator!=); +#endif + } + + Vectorized operator<(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcltq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator<); +#endif + } + + Vectorized operator<=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcleq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator<=); +#endif + } + + Vectorized operator>(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcgtq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator>); +#endif + } + + Vectorized operator>=(const Vectorized& other) const { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized( + vreinterpretq_f16_u16(vcgeq_f16(values, other.values))); +#else + return map2_bitmask_with_vec_float_method( + other, &Vectorized::operator>=); +#endif + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; // Vectorized + +inline std::tuple, Vectorized> convert_half_float( + const Vectorized& a) { + static_assert(Vectorized::size() == 2 * Vectorized::size()); + float16x8_t x = a; + float32x4_t x1 = vcvt_f32_f16(vget_low_f16(x)); + float32x4_t x2 = vcvt_f32_f16(vget_high_f16(x)); + return {Vectorized(x1), Vectorized(x2)}; +} +inline Vectorized convert_float_half( + const Vectorized& a, + const Vectorized& b) { + static_assert(Vectorized::size() == 2 * Vectorized::size()); + float32x4_t x = a; + float32x4_t y = b; + float16x4_t x1 = vcvt_f16_f32(x); + float16x4_t x2 = vcvt_f16_f32(y); + return Vectorized(vcombine_f16(x1, x2)); +} + +template +Vectorized binary_operator_via_float( + Op op, + const Vectorized& a, + const Vectorized& b) { + const auto [a_float_low, a_float_high] = convert_half_float(a); + const auto [b_float_low, b_float_high] = convert_half_float(b); + return convert_float_half( + op(a_float_low, b_float_low), op(a_float_high, b_float_high)); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vaddq_f16(a, b)); +#else + return binary_operator_via_float(std::plus>(), a, b); +#endif +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vsubq_f16(a, b)); +#else + return binary_operator_via_float(std::minus>(), a, b); +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vmulq_f16(a, b)); +#else + return binary_operator_via_float(std::multiplies>(), a, b); +#endif +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vdivq_f16(a, b)); +#else + return binary_operator_via_float(std::divides>(), a, b); +#endif +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vmaxq_f16(a, b)); +#else + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&maximum), + a, + b); +#endif +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vminq_f16(a, b)); +#else + return binary_operator_via_float( + static_cast (*)( + const Vectorized&, const Vectorized&)>(&minimum), + a, + b); +#endif +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return Vectorized(vreinterpretq_f16_u16( + veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vfmaq_f16(c, a, b)); +#else + return a * b + c; +#endif +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vfmsq_f16(c, a, b)); +#else + return -a * b + c; +#endif +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vnegq_f16(vfmsq_f16(c, a, b))); +#else + return a * b - c; +#endif +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + return Vectorized(vnegq_f16(vfmaq_f16(c, a, b))); +#else + return -a * b - c; +#endif +} +#endif // !defined(C10_MOBILE) && defined(__aarch64__) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_int_aarch64.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_int_aarch64.h new file mode 100644 index 0000000000000000000000000000000000000000..7d5a95e2fc54ae704bb019f50ae8347a6be93938 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_int_aarch64.h @@ -0,0 +1,799 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#define VEC_INT_NEON_TEMPLATE(vl, bit) \ + template <> \ + struct is_vec_specialized_for : std::bool_constant {}; \ + \ + template <> \ + class Vectorized { \ + using neon_type = int##bit##x##vl##_t; \ + \ + private: \ + neon_type values; \ + \ + public: \ + using value_type = int##bit##_t; \ + using size_type = int; \ + static constexpr size_type size() { \ + return vl; \ + } \ + Vectorized() { \ + values = vdupq_n_s##bit(0); \ + } \ + Vectorized(neon_type v) : values(v) {} \ + Vectorized(int##bit##_t val); \ + template < \ + typename... Args, \ + typename = std::enable_if_t<(sizeof...(Args) == size())>> \ + Vectorized(Args... vals) { \ + __at_align__ int##bit##_t buffer[size()] = {vals...}; \ + values = vld1q_s##bit(buffer); \ + } \ + operator neon_type() const { \ + return values; \ + } \ + static Vectorized loadu( \ + const void* ptr, \ + int64_t count = size()); \ + void store(void* ptr, int64_t count = size()) const; \ + template \ + static Vectorized blend( \ + const Vectorized& a, \ + const Vectorized& b); \ + static Vectorized blendv( \ + const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& mask_) { \ + return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \ + } \ + template \ + static Vectorized arange( \ + value_type base = 0, \ + step_t step = static_cast(1)); \ + static Vectorized set( \ + const Vectorized& a, \ + const Vectorized& b, \ + int64_t count = size()); \ + const int##bit##_t& operator[](int idx) const = delete; \ + int##bit##_t& operator[](int idx) = delete; \ + Vectorized abs() const { \ + return vabsq_s##bit(values); \ + } \ + Vectorized real() const { \ + return values; \ + } \ + Vectorized imag() const { \ + return vdupq_n_s##bit(0); \ + } \ + Vectorized conj() const { \ + return values; \ + } \ + Vectorized neg() const { \ + return vnegq_s##bit(values); \ + } \ + int##bit##_t reduce_add() const { \ + return vaddvq_s##bit(values); \ + } \ + int##bit##_t reduce_max() const; \ + Vectorized operator==( \ + const Vectorized& other) const { \ + return Vectorized( \ + vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \ + } \ + Vectorized operator!=( \ + const Vectorized& other) const; \ + Vectorized operator<( \ + const Vectorized& other) const { \ + return Vectorized( \ + vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \ + } \ + Vectorized operator<=( \ + const Vectorized& other) const { \ + return Vectorized( \ + vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \ + } \ + Vectorized operator>( \ + const Vectorized& other) const { \ + return Vectorized( \ + vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \ + } \ + Vectorized operator>=( \ + const Vectorized& other) const { \ + return Vectorized( \ + vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \ + } \ + Vectorized eq(const Vectorized& other) const; \ + Vectorized ne(const Vectorized& other) const; \ + Vectorized gt(const Vectorized& other) const; \ + Vectorized ge(const Vectorized& other) const; \ + Vectorized lt(const Vectorized& other) const; \ + Vectorized le(const Vectorized& other) const; \ + }; \ + template <> \ + Vectorized inline operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return vaddq_s##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return vsubq_s##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return vandq_s##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return vorrq_s##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return veorq_s##bit(a, b); \ + } \ + Vectorized inline Vectorized::eq( \ + const Vectorized& other) const { \ + return (*this == other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ne( \ + const Vectorized& other) const { \ + return (*this != other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::gt( \ + const Vectorized& other) const { \ + return (*this > other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ge( \ + const Vectorized& other) const { \ + return (*this >= other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::lt( \ + const Vectorized& other) const { \ + return (*this < other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::le( \ + const Vectorized& other) const { \ + return (*this <= other) & Vectorized(1); \ + } + +VEC_INT_NEON_TEMPLATE(2, 64) +VEC_INT_NEON_TEMPLATE(4, 32) +VEC_INT_NEON_TEMPLATE(8, 16) +VEC_INT_NEON_TEMPLATE(16, 8) + +inline int32_t Vectorized::reduce_max() const { + return vmaxvq_s32(values); +} + +inline int16_t Vectorized::reduce_max() const { + return vmaxvq_s16(values); +} + +inline int8_t Vectorized::reduce_max() const { + return vmaxvq_s8(values); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return vmulq_s32(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return vmulq_s16(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return vmulq_s8(a, b); +} + +template <> +inline Vectorized operator~(const Vectorized& a) { + int64x2_t val = a; + return ~val; +} + +template <> +inline Vectorized operator~(const Vectorized& a) { + return vmvnq_s32(a); +} + +template <> +inline Vectorized operator~(const Vectorized& a) { + return vmvnq_s16(a); +} + +template <> +inline Vectorized operator~(const Vectorized& a) { + return vmvnq_s8(a); +} + +inline Vectorized Vectorized::operator!=( + const Vectorized& other) const { + return ~(*this == other); +} + +inline Vectorized Vectorized::operator!=( + const Vectorized& other) const { + return ~(*this == other); +} + +inline Vectorized Vectorized::operator!=( + const Vectorized& other) const { + return ~(*this == other); +} + +inline Vectorized Vectorized::operator!=( + const Vectorized& other) const { + return ~(*this == other); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return vminq_s32(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return vminq_s16(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return vminq_s8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return vmaxq_s32(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return vmaxq_s16(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return vmaxq_s8(a, b); +} + +template +Vectorized Vectorized::blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each bit of element is 1 if the corresponding bit + // in 'mask' is set, 0 otherwise. + uint64x2_t maskArray = { + (mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0, + (mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_s64(maskArray, b.values, a.values); +} + +template +Vectorized Vectorized::blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each bit of element is 1 if the corresponding bit + // in 'mask' is set, 0 otherwise. + uint32x4_t maskArray = { + (mask & 1LL) ? 0xFFFFFFFF : 0, + (mask & 2LL) ? 0xFFFFFFFF : 0, + (mask & 4LL) ? 0xFFFFFFFF : 0, + (mask & 8LL) ? 0xFFFFFFFF : 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_s32(maskArray, b.values, a.values); +} + +template +Vectorized Vectorized::blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each bit of element is 1 if the corresponding bit + // in 'mask' is set, 0 otherwise. + uint16x8_t maskArray = { + (mask & 1LL) ? 0xFFFF : 0, + (mask & 2LL) ? 0xFFFF : 0, + (mask & 4LL) ? 0xFFFF : 0, + (mask & 8LL) ? 0xFFFF : 0, + (mask & 16LL) ? 0xFFFF : 0, + (mask & 32LL) ? 0xFFFF : 0, + (mask & 64LL) ? 0xFFFF : 0, + (mask & 128LL) ? 0xFFFF : 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_s16(maskArray, b.values, a.values); +} + +template +Vectorized Vectorized::blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each bit of element is 1 if the corresponding bit + // in 'mask' is set, 0 otherwise. + uint8x16_t maskArray = { + (mask & 1LL) ? 0xFF : 0, + (mask & 2LL) ? 0xFF : 0, + (mask & 4LL) ? 0xFF : 0, + (mask & 8LL) ? 0xFF : 0, + (mask & 16LL) ? 0xFF : 0, + (mask & 32LL) ? 0xFF : 0, + (mask & 64LL) ? 0xFF : 0, + (mask & 128LL) ? 0xFF : 0, + (mask & 256LL) ? 0xFF : 0, + (mask & 512LL) ? 0xFF : 0, + (mask & 1024LL) ? 0xFF : 0, + (mask & 2048LL) ? 0xFF : 0, + (mask & 4096LL) ? 0xFF : 0, + (mask & 8192LL) ? 0xFF : 0, + (mask & 16384LL) ? 0xFF : 0, + (mask & 32768LL) ? 0xFF : 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_s8(maskArray, b.values, a.values); +} + +#define VEC_INT_NEON_OPS(vl, bit) \ + inline Vectorized::Vectorized(int##bit##_t val) { \ + values = vdupq_n_s##bit(val); \ + } \ + inline Vectorized Vectorized::loadu( \ + const void* ptr, int64_t count) { \ + if (count == size()) { \ + return vld1q_s##bit(reinterpret_cast(ptr)); \ + } else { \ + __at_align__ int##bit##_t tmp_values[size()]; \ + for (const auto i : c10::irange(size())) { \ + tmp_values[i] = 0; \ + } \ + std::memcpy( \ + tmp_values, \ + reinterpret_cast(ptr), \ + count * sizeof(int##bit##_t)); \ + return vld1q_s##bit(reinterpret_cast(tmp_values)); \ + } \ + } \ + inline void Vectorized::store(void* ptr, int64_t count) \ + const { \ + if (count == size()) { \ + vst1q_s##bit(reinterpret_cast(ptr), values); \ + } else { \ + int##bit##_t tmp_values[size()]; \ + vst1q_s##bit(reinterpret_cast(tmp_values), values); \ + std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t)); \ + } \ + } + +VEC_INT_NEON_OPS(2, 64) +VEC_INT_NEON_OPS(4, 32) +VEC_INT_NEON_OPS(8, 16) +VEC_INT_NEON_OPS(16, 8) + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + int64x2_t x = a; + int64x2_t y = b; + return x * y; +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + int64x2_t x = a; + int64x2_t y = b; + return x / y; +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + int32x4_t x = a; + int32x4_t y = b; + return x / y; +} + +inline int64_t Vectorized::reduce_max() const { + return std::max(values[0], values[1]); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + int64x2_t x = a; + int64x2_t y = b; + return {std::min(x[0], y[0]), std::min(x[1], y[1])}; +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + int64x2_t x = a; + int64x2_t y = b; + return {std::max(x[0], y[0]), std::max(x[1], y[1])}; +} + +template +inline Vectorized Vectorized::arange( + int64_t base, + step_t step) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const int64x2_t step_sizes = {0, 1}; + return base_vec.values + step_sizes * step_vec.values; +} + +template +inline Vectorized Vectorized::arange( + int32_t base, + step_t step) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const int32x4_t step_sizes = {0, 1, 2, 3}; + return vmlaq_s32(base_vec, step_sizes, step_vec); +} + +template +inline Vectorized Vectorized::arange( + int16_t base, + step_t step) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7}; + return vmlaq_s16(base_vec, step_sizes, step_vec); +} + +template +inline Vectorized Vectorized::arange(int8_t base, step_t step) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const int8x16_t step_sizes = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + return vmlaq_s8(base_vec, step_sizes, step_vec); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + int64x2_t x = a; + int64x2_t y = b; + uint64x2_t u = vreinterpretq_u64_s64(y); + uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)}; + return x >> vreinterpretq_s64_u64(z); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + int32x4_t x = a; + int32x4_t y = b; + uint32x4_t bound = vdupq_n_u32(31); + uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); + return x >> vreinterpretq_s32_u32(z); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + int16x8_t x = a; + int16x8_t y = b; + uint16x8_t bound = vdupq_n_u16(15); + uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); + return x >> vreinterpretq_s16_u16(z); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + int8x16_t x = a; + int8x16_t y = b; + uint8x16_t bound = vdupq_n_u8(7); + int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); + return x >> z; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + int64x2_t y = b; + uint64x2_t u = vreinterpretq_u64_s64(y); + uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)}; + return vshlq_s64(a, vreinterpretq_s64_u64(z)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + int32x4_t y = b; + uint32x4_t bound = vdupq_n_u32(32); + uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); + return vshlq_s32(a, vreinterpretq_s32_u32(z)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + int16x8_t y = b; + uint16x8_t bound = vdupq_n_u16(16); + uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); + return vshlq_s16(a, vreinterpretq_s16_u16(z)); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + int8x16_t y = b; + uint8x16_t bound = vdupq_n_u8(8); + int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); + return vshlq_s8(a, z); +} + +inline Vectorized Vectorized::set( + const Vectorized& a, + const Vectorized& b, + int64_t count) { + if (count == 0) { + return a; + } else if (count >= 2) { + return b; + } else { + int64x2_t c = {b.values[0], a.values[1]}; + return c; + } +} + +inline Vectorized Vectorized::set( + const Vectorized& a, + const Vectorized& b, + int64_t count) { + if (count == 0) { + return a; + } else if (count >= 4) { + return b; + } else { + // Build an array of flags: each bit of element is 1 if the corresponding + // bit in 'mask' is set, 0 otherwise. + uint32x4_t maskArray = { + (count >= 1LL) ? 0xFFFFFFFF : 0, + (count >= 2LL) ? 0xFFFFFFFF : 0, + (count >= 3LL) ? 0xFFFFFFFF : 0, + 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_s32(maskArray, b.values, a.values); + } +} + +inline Vectorized Vectorized::set( + const Vectorized& a, + const Vectorized& b, + int64_t count) { + if (count == 0) { + return a; + } else if (count >= 8) { + return b; + } else { + // Build an array of flags: each bit of element is 1 if the corresponding + // bit in 'mask' is set, 0 otherwise. + uint16x8_t maskArray = { + static_cast((count >= 1LL) ? 0xFFFF : 0), + static_cast((count >= 2LL) ? 0xFFFF : 0), + static_cast((count >= 3LL) ? 0xFFFF : 0), + static_cast((count >= 4LL) ? 0xFFFF : 0), + static_cast((count >= 5LL) ? 0xFFFF : 0), + static_cast((count >= 6LL) ? 0xFFFF : 0), + static_cast((count >= 7LL) ? 0xFFFF : 0), + 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_s16(maskArray, b.values, a.values); + } +} + +inline Vectorized Vectorized::set( + const Vectorized& a, + const Vectorized& b, + int64_t count) { + if (count == 0) { + return a; + } else if (count >= 16) { + return b; + } else { + // Build an array of flags: each bit of element is 1 if the corresponding + // bit in 'mask' is set, 0 otherwise. + uint8x16_t maskArray = { + static_cast((count >= 1LL) ? 0xFF : 0), + static_cast((count >= 2LL) ? 0xFF : 0), + static_cast((count >= 3LL) ? 0xFF : 0), + static_cast((count >= 4LL) ? 0xFF : 0), + static_cast((count >= 5LL) ? 0xFF : 0), + static_cast((count >= 6LL) ? 0xFF : 0), + static_cast((count >= 7LL) ? 0xFF : 0), + static_cast((count >= 8LL) ? 0xFF : 0), + static_cast((count >= 9LL) ? 0xFF : 0), + static_cast((count >= 10LL) ? 0xFF : 0), + static_cast((count >= 11LL) ? 0xFF : 0), + static_cast((count >= 12LL) ? 0xFF : 0), + static_cast((count >= 13LL) ? 0xFF : 0), + static_cast((count >= 14LL) ? 0xFF : 0), + static_cast((count >= 15LL) ? 0xFF : 0), + 0}; + + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_s8(maskArray, b.values, a.values); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + Vectorized highBitsA = vmovl_high_s16(a); + Vectorized highBitsB = vmovl_high_s16(b); + Vectorized lowBitsA = vmovl_s16(vget_low_s16(a)); + Vectorized lowBitsB = vmovl_s16(vget_low_s16(b)); + int32x4_t highBitsResult = highBitsA / highBitsB; + int32x4_t lowBitsResult = lowBitsA / lowBitsB; + return vuzp1q_s16( + vreinterpretq_s16_s32(lowBitsResult), + vreinterpretq_s16_s32(highBitsResult)); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + Vectorized highBitsA = vmovl_high_s8(a); + Vectorized highBitsB = vmovl_high_s8(b); + Vectorized lowBitsA = vmovl_s8(vget_low_s8(a)); + Vectorized lowBitsB = vmovl_s8(vget_low_s8(b)); + int16x8_t highBitsResult = highBitsA / highBitsB; + int16x8_t lowBitsResult = lowBitsA / lowBitsB; + return vuzp1q_s8( + vreinterpretq_s8_s16(lowBitsResult), + vreinterpretq_s8_s16(highBitsResult)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h new file mode 100644 index 0000000000000000000000000000000000000000..3c6e2cc667d373343de56c1dbb0bfa7c28d99f39 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h @@ -0,0 +1,316 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// Shared code for bfloat16 and float16. + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +// Shared implementation between Vectorized and +// Vectorized. Uses CRTP to allow derived class +// customization. +template < + typename VecT, + typename ValueT, + template typename BlendRegs, + typename Derived> +struct Vectorized16 { + protected: + VecT values; + + public: + using value_type = ValueT; + using size_type = int; + static constexpr size_type size() { + static_assert(sizeof(VecT) == 8 * sizeof(value_type)); + return 8; + } + + protected: + Derived map2( + const Derived& second, + value_type (*const f)(value_type, value_type)) const { + __at_align__ value_type tmp_first[size()]; + __at_align__ value_type tmp_second[size()]; + static_cast(this)->store( + tmp_first); // store this to tmp_first + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp_first[i] = f(tmp_first[i], tmp_second[i]); + } + return Derived::loadu(tmp_first); + } + + public: + Vectorized16() = default; + Vectorized16(VecT v) : values(v) {} + + operator VecT() const { + return values; + } + + template + static Derived blend(const Derived& a, const Derived& b) { + Derived vec; + vec.values = BlendRegs < 0, + (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 1, + (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 2, + (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 3, + (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values); + + vec.values = BlendRegs < 4, + (mask & 0x10) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 5, + (mask & 0x20) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 6, + (mask & 0x40) != 0 > ::impl(a.values, b.values, vec.values); + vec.values = BlendRegs < 7, + (mask & 0x80) != 0 > ::impl(a.values, b.values, vec.values); + + return vec; + } + + template + static Derived arange( + value_type base = 0, + step_t step = static_cast(1)) { + const Derived base_vec(base); + const Derived step_vec(step); + const Derived step_sizes( + value_type(0), + value_type(1), + value_type(2), + value_type(3), + value_type(4), + value_type(5), + value_type(6), + value_type(7)); + return fmadd(step_sizes, step_vec, base_vec); + } + + // Very slow implementation of indexing. + // Only required because vec256_qint refers to this. + // Once we specialize that implementation for ARM + // this should be removed. TODO (kimishpatel) + value_type operator[](int idx) const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + return tmp[idx]; + } + + int zero_mask() const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0) { + mask |= (1 << i); + } + } + return mask; + } + + Derived map(value_type (*const f)(value_type)) const { + __at_align__ value_type tmp[size()]; + static_cast(this)->store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return Derived::loadu(tmp); + } + + Derived angle() const { + auto zero = Derived(0); + auto pi = Derived(c10::pi); + auto tmp = + Derived::blendv(zero, pi, *static_cast(this) < zero); + return Derived::blendv( + tmp, + *static_cast(this), + static_cast(this)->isnan()); + } + Derived real() const { + return *this; + } + Derived imag() const { + return Derived(0); + } + Derived conj() const { + return *this; + } + + // Sleef does not support FP16/BF16, so many math functions are applied by + // converting to FP32, applying the math function, and then converting back to + // FP16/BF16. + Derived acos() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::acos); + } + Derived acosh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::acosh); + } + Derived asin() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::asin); + } + Derived asinh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::asinh); + } + Derived atan() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::atan); + } + Derived atanh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::atanh); + } + Derived atan2(const Derived& exp) const { + return static_cast(this)->map2_with_vec_float_method( + exp, &Vectorized::atan2); + } + Derived copysign(const Derived& sign) const { + return static_cast(this)->map2_with_vec_float_method( + sign, &Vectorized::copysign); + } + Derived erf() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erf); + } + Derived erfc() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erfc); + } + Derived erfinv() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::erfinv); + } + Derived exp() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp); + } + Derived exp2() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp2); + } + Derived expm1() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::expm1); + } + Derived exp_u20() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp_u20); + } + Derived fexp_u20() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::exp_u20); + } + Derived fmod(const Derived& q) const { + // This function is questionable with a conversion, so we use map2 + return map2(q, std::fmod); + } + Derived hypot(const Derived& b) const { + return static_cast(this)->map2_with_vec_float_method( + b, &Vectorized::hypot); + } + Derived i0() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::i0); + } + Derived i0e() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::i0e); + } + Derived digamma() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::digamma); + } + Derived igamma(const Derived& x) const { + return static_cast(this)->map2_with_vec_float_method( + x, &Vectorized::igamma); + } + Derived igammac(const Derived& x) const { + return static_cast(this)->map2_with_vec_float_method( + x, &Vectorized::igammac); + } + Derived log() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log); + } + Derived log10() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log10); + } + Derived log1p() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log1p); + } + Derived log2() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::log2); + } + Derived nextafter(const Derived& b) const { + // This function does not make sense with conversion, so we use map2 + return map2(b, std::nextafter); + } + Derived sin() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::sin); + } + Derived sinh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::sinh); + } + Derived cos() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::cos); + } + Derived cosh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::cosh); + } + Derived ceil() const { + // This function is questionable with a conversion, so we use map + return map(at::native::ceil_impl); + } + Derived floor() const { + // This function is questionable with a conversion, so we use map + return map(at::native::floor_impl); + } + Derived round() const { + // This function is questionable with a conversion, so we use map + return map(at::native::round_impl); + } + Derived tan() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::tan); + } + Derived tanh() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::tanh); + } + Derived lgamma() const { + return static_cast(this)->map_with_vec_float_method( + &Vectorized::lgamma); + } + Derived rsqrt() const { + return static_cast(this)->sqrt().reciprocal(); + } + Derived pow(const Derived& exp) const { + return static_cast(this)->map2_with_vec_float_method( + exp, &Vectorized::pow); + } +}; + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_uint_aarch64.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_uint_aarch64.h new file mode 100644 index 0000000000000000000000000000000000000000..f8c811704314cceb401a0ed793a219332977fded --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_uint_aarch64.h @@ -0,0 +1,383 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#define VEC_UINT_NEON_TEMPLATE(vl, bit) \ + template <> \ + struct is_vec_specialized_for : std::bool_constant {}; \ + \ + template <> \ + class Vectorized { \ + using neon_type = uint##bit##x##vl##_t; \ + \ + private: \ + neon_type values; \ + \ + public: \ + using value_type = uint##bit##_t; \ + using size_type = int; \ + static constexpr size_type size() { \ + return vl; \ + } \ + Vectorized() { \ + values = vdupq_n_u##bit(0); \ + } \ + Vectorized(neon_type v) : values(v) {} \ + Vectorized(uint##bit##_t val); \ + template < \ + typename... Args, \ + typename = std::enable_if_t<(sizeof...(Args) == size())>> \ + Vectorized(Args... vals) { \ + __at_align__ uint##bit##_t buffer[size()] = {vals...}; \ + values = vld1q_u##bit(buffer); \ + } \ + operator neon_type() const { \ + return values; \ + } \ + static Vectorized loadu( \ + const void* ptr, \ + uint64_t count = size()); \ + void store(void* ptr, uint64_t count = size()) const; \ + template \ + static Vectorized blend( \ + const Vectorized& a, \ + const Vectorized& b); \ + static Vectorized blendv( \ + const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& mask_) { \ + return vbslq_u##bit(mask_.values, b, a); \ + } \ + template \ + static Vectorized arange( \ + value_type base = 0, \ + step_t step = static_cast(1)); \ + static Vectorized set( \ + const Vectorized& a, \ + const Vectorized& b, \ + uint64_t count = size()); \ + const uint##bit##_t& operator[](uint idx) const = delete; \ + uint##bit##_t& operator[](uint idx) = delete; \ + Vectorized abs() const { \ + return values; \ + } \ + Vectorized real() const { \ + return values; \ + } \ + Vectorized imag() const { \ + return vdupq_n_u##bit(0); \ + } \ + Vectorized conj() const { \ + return values; \ + } \ + Vectorized neg() const { \ + return vreinterpretq_u##bit##_s##bit( \ + vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \ + } \ + uint##bit##_t reduce_add() const { \ + return vaddvq_u##bit(values); \ + } \ + uint##bit##_t reduce_max() const; \ + Vectorized operator==( \ + const Vectorized& other) const { \ + return Vectorized(vceqq_u##bit(values, other.values)); \ + } \ + Vectorized operator!=( \ + const Vectorized& other) const; \ + Vectorized operator<( \ + const Vectorized& other) const { \ + return Vectorized(vcltq_u##bit(values, other.values)); \ + } \ + Vectorized operator<=( \ + const Vectorized& other) const { \ + return Vectorized(vcleq_u##bit(values, other.values)); \ + } \ + Vectorized operator>( \ + const Vectorized& other) const { \ + return Vectorized(vcgtq_u##bit(values, other.values)); \ + } \ + Vectorized operator>=( \ + const Vectorized& other) const { \ + return Vectorized(vcgeq_u##bit(values, other.values)); \ + } \ + Vectorized eq( \ + const Vectorized& other) const; \ + Vectorized ne( \ + const Vectorized& other) const; \ + Vectorized gt( \ + const Vectorized& other) const; \ + Vectorized ge( \ + const Vectorized& other) const; \ + Vectorized lt( \ + const Vectorized& other) const; \ + Vectorized le( \ + const Vectorized& other) const; \ + }; \ + template <> \ + Vectorized inline operator+( \ + const Vectorized& a, \ + const Vectorized& b) { \ + return vaddq_u##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator-( \ + const Vectorized& a, \ + const Vectorized& b) { \ + return vsubq_u##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator&( \ + const Vectorized& a, \ + const Vectorized& b) { \ + return vandq_u##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator|( \ + const Vectorized& a, \ + const Vectorized& b) { \ + return vorrq_u##bit(a, b); \ + } \ + template <> \ + Vectorized inline operator^( \ + const Vectorized& a, \ + const Vectorized& b) { \ + return veorq_u##bit(a, b); \ + } \ + Vectorized inline Vectorized::eq( \ + const Vectorized& other) const { \ + return (*this == other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ne( \ + const Vectorized& other) const { \ + return (*this != other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::gt( \ + const Vectorized& other) const { \ + return (*this > other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::ge( \ + const Vectorized& other) const { \ + return (*this >= other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::lt( \ + const Vectorized& other) const { \ + return (*this < other) & Vectorized(1); \ + } \ + Vectorized inline Vectorized::le( \ + const Vectorized& other) const { \ + return (*this <= other) & Vectorized(1); \ + } + +VEC_UINT_NEON_TEMPLATE(16, 8) + +inline uint8_t Vectorized::reduce_max() const { + return vmaxvq_u8(values); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return vmulq_u8(a, b); +} + +template <> +inline Vectorized operator~(const Vectorized& a) { + return vmvnq_u8(a); +} + +inline Vectorized Vectorized::operator!=( + const Vectorized& other) const { + return ~(*this == other); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return vminq_u8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return vmaxq_u8(a, b); +} + +template +Vectorized Vectorized::blend( + const Vectorized& a, + const Vectorized& b) { + // Build an array of flags: each bit of element is 1 if the corresponding bit + // in 'mask' is set, 0 otherwise. + uint8x16_t maskArray = { + (mask & 1LL) ? 0xFF : 0, + (mask & 2LL) ? 0xFF : 0, + (mask & 4LL) ? 0xFF : 0, + (mask & 8LL) ? 0xFF : 0, + (mask & 16LL) ? 0xFF : 0, + (mask & 32LL) ? 0xFF : 0, + (mask & 64LL) ? 0xFF : 0, + (mask & 128LL) ? 0xFF : 0, + (mask & 256LL) ? 0xFF : 0, + (mask & 512LL) ? 0xFF : 0, + (mask & 1024LL) ? 0xFF : 0, + (mask & 2048LL) ? 0xFF : 0, + (mask & 4096LL) ? 0xFF : 0, + (mask & 8192LL) ? 0xFF : 0, + (mask & 16384LL) ? 0xFF : 0, + (mask & 32768LL) ? 0xFF : 0}; + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_u8(maskArray, b.values, a.values); +} + +#define VEC_UINT_NEON_OPS(vl, bit) \ + inline Vectorized::Vectorized(uint##bit##_t val) { \ + values = vdupq_n_u##bit(val); \ + } \ + inline Vectorized Vectorized::loadu( \ + const void* ptr, uint64_t count) { \ + if (count == size()) { \ + return vld1q_u##bit(reinterpret_cast(ptr)); \ + } else { \ + __at_align__ uint##bit##_t tmp_values[size()]; \ + for (const auto i : c10::irange(size())) { \ + tmp_values[i] = 0; \ + } \ + std::memcpy( \ + tmp_values, \ + reinterpret_cast(ptr), \ + count * sizeof(uint##bit##_t)); \ + return vld1q_u##bit(reinterpret_cast(tmp_values)); \ + } \ + } \ + inline void Vectorized::store(void* ptr, uint64_t count) \ + const { \ + if (count == size()) { \ + vst1q_u##bit(reinterpret_cast(ptr), values); \ + } else { \ + uint##bit##_t tmp_values[size()]; \ + vst1q_u##bit(reinterpret_cast(tmp_values), values); \ + std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \ + } \ + } + +VEC_UINT_NEON_OPS(16, 8) + +template +inline Vectorized Vectorized::arange( + uint8_t base, + step_t step) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const uint8x16_t step_sizes = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + return vmlaq_u8(base_vec, step_sizes, step_vec); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + uint8x16_t x = a; + uint8x16_t bound = vdupq_n_u8(8); + uint8x16_t z = vminq_u8(b, bound); + return x >> z; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + uint8x16_t bound = vdupq_n_u8(8); + uint8x16_t z = vminq_u8(b, bound); + return vshlq_u8(a, vreinterpretq_s8_u8(z)); +} + +inline Vectorized Vectorized::set( + const Vectorized& a, + const Vectorized& b, + uint64_t count) { + if (count == 0) { + return a; + } else if (count >= 16) { + return b; + } else { + // Build an array of flags: each bit of element is 1 if the corresponding + // bit in 'mask' is set, 0 otherwise. + uint8x16_t maskArray = { + static_cast((count >= 1LL) ? 0xFF : 0), + static_cast((count >= 2LL) ? 0xFF : 0), + static_cast((count >= 3LL) ? 0xFF : 0), + static_cast((count >= 4LL) ? 0xFF : 0), + static_cast((count >= 5LL) ? 0xFF : 0), + static_cast((count >= 6LL) ? 0xFF : 0), + static_cast((count >= 7LL) ? 0xFF : 0), + static_cast((count >= 8LL) ? 0xFF : 0), + static_cast((count >= 9LL) ? 0xFF : 0), + static_cast((count >= 10LL) ? 0xFF : 0), + static_cast((count >= 11LL) ? 0xFF : 0), + static_cast((count >= 12LL) ? 0xFF : 0), + static_cast((count >= 13LL) ? 0xFF : 0), + static_cast((count >= 14LL) ? 0xFF : 0), + static_cast((count >= 15LL) ? 0xFF : 0), + 0}; + + // Use BSL to select elements from b where the mask is 1, else from a + return vbslq_u8(maskArray, b.values, a.values); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + uint8x16_t x = a; + uint8x16_t y = b; + return x / y; +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return maximum(min, a); +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h new file mode 100644 index 0000000000000000000000000000000000000000..6745dd7eb2a1f371b45d5e21fe2f52276cf864db --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h @@ -0,0 +1,435 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +#include +#if !( \ + defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || \ + defined(CPU_CAPABILITY_ZVECTOR)) +#if defined(CPU_CAPABILITY_SVE256) +#include +#else +// clang-format off +#include +#include +#include +#include +#endif +#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16) +#include +#endif +#include +#include +#include +// clang-format on +#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) +#include +#else +// clang-format off +#include +#include +#include +// clang-format on +#endif + +#include +#include + +#include +#include +#include +#include +#include + +namespace at::vec { + +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; +} + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << ']'; + return stream; +} + +#if defined(CPU_CAPABILITY_AVX2) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castpd_ps(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castps_pd(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm256_castsi256_ps(src); +} + +template <> +inline Vectorized cast( + const Vectorized& src) { + return _mm256_castsi256_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm256_i64gather_pd(base_addr, vindex, scale); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm256_i32gather_ps(base_addr, vindex, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000)); + return _mm256_sub_epi64( + _mm256_castpd_si256(x), + _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm256_cvttps_epi32(src); +} + +// From: https://stackoverflow.com/a/41148578 +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + __m256i magic_i_lo = _mm256_set1_epi64x(0x4330000000000000); /* 2^52 */ + __m256i magic_i_hi32 = + _mm256_set1_epi64x(0x4530000080000000); /* 2^84 + 2^63 */ + __m256i magic_i_all = + _mm256_set1_epi64x(0x4530000080100000); /* 2^84 + 2^63 + 2^52 */ + __m256d magic_d_all = _mm256_castsi256_pd(magic_i_all); + + __m256i v_lo = _mm256_blend_epi32( + magic_i_lo, src, 0b01010101); /* v_low = low32 + 2^52 */ + __m256i v_hi = _mm256_srli_epi64(src, 32); + v_hi = _mm256_xor_si256( + v_hi, magic_i_hi32); /* v_hi = high32*2^32 + 2^84 + 2^63 */ + /* int64 = low32 + high32*2^32 = v_hi + v_lo - 2^52 - 2^63 - 2^84 */ + __m256d v_hi_dbl = _mm256_sub_pd(_mm256_castsi256_pd(v_hi), magic_d_all); + __m256d result = _mm256_add_pd(v_hi_dbl, _mm256_castsi256_pd(v_lo)); + return result; +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm256_cvtepi32_ps(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + // swap lanes: + // a_swapped = {a0, a1, b0, b1} + // b_swapped = {a2, a3, b2, b3} + auto a_swapped = + _mm256_permute2f128_pd(a, b, 0b0100000); // 0, 2. 4 bits apart + auto b_swapped = + _mm256_permute2f128_pd(a, b, 0b0110001); // 1, 3. 4 bits apart + + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + _mm256_permute4x64_pd(a_swapped, 0b11011000), // 0, 2, 1, 3 + _mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3 +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + + // swap lanes: + // a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3} + // b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7} + // TODO: can we support caching this? + auto a_swapped = + _mm256_permute2f128_ps(a, b, 0b0100000); // 0, 2. 4 bits apart + auto b_swapped = + _mm256_permute2f128_ps(a, b, 0b0110001); // 1, 3. 4 bits apart + + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + return std::make_pair( + _mm256_permutevar8x32_ps(a_swapped, group_ctrl), + _mm256_permutevar8x32_ps(b_swapped, group_ctrl)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + + // group cols crossing lanes: + // a_grouped = {a0, a1, b0, b1} + // b_grouped = {a2, a3, b2, b3} + auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000); // 0, 2, 1, 3 + auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000); // 0, 2, 1, 3 + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + _mm256_permute2f128_pd( + a_grouped, b_grouped, 0b0100000), // 0, 2. 4 bits apart + _mm256_permute2f128_pd( + a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + + // group cols crossing lanes: + // a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3} + // b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7} + // TODO: can we support caching this? + const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl); + auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl); + + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair( + _mm256_permute2f128_ps( + a_grouped, b_grouped, 0b0100000), // 0, 2. 4 bits apart + _mm256_permute2f128_ps( + a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_ps(v, mask_float); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_epi32(v, mask_int32); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m256i mask = _mm256_set_epi8( + 1, + 0, + 3, + 2, + 5, + 4, + 7, + 6, + 9, + 8, + 11, + 10, + 13, + 12, + 15, + 14, + 1, + 0, + 3, + 2, + 5, + 4, + 7, + 6, + 9, + 8, + 11, + 10, + 13, + 12, + 15, + 14); + auto reversed = _mm256_shuffle_epi8(v, mask); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +inline __m256i flip8(const __m256i& v) { + const __m256i mask_int8 = _mm256_set_epi8( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15); + auto reversed = _mm256_shuffle_epi8(v, mask_int8); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m256i* self_ = reinterpret_cast(self.as_bytes()); + const __m256i* other_ = reinterpret_cast(other.as_bytes()); + __m256i out = _mm256_and_si256(*self_, *other_); + Vectorized ret; + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + +#endif // (defined(CPU_CAPABILITY_AVX2) + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h new file mode 100644 index 0000000000000000000000000000000000000000..2a585884e36ebdb20ef32ef8dc0e9f82d02895ba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h @@ -0,0 +1,837 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +// Used for shared functions and classes for vec256_bfloat16.h and +// vec256_half.h. Any functions/classes that are common between those two files +// should be defined here. Any non-shared functions/classes should be defined in +// the respective files. + +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +#ifndef SLEEF_CONST +#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER) +#define SLEEF_CONST const +#else +#define SLEEF_CONST +#endif +#define SLEEF_CONST_OLD SLEEF_CONST +#else +#define SLEEF_CONST_OLD +#endif + +// bfloat16 conversion +static inline void cvtbf16_fp32(const __m128i& a, __m256& o) { + o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) { + __m128i lo = _mm256_extractf128_si256(a, 0); + __m128i hi = _mm256_extractf128_si256(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m128i cvtfp32_bf16(const __m256& src) { + __m256i value = _mm256_castps_si256(src); + __m256i nan = _mm256_set1_epi32(0xffff); + __m256i mask = _mm256_castps_si256(_mm256_cmp_ps(src, src, _CMP_ORD_Q)); + __m256i ones = _mm256_set1_epi32(0x1); + __m256i vec_bias = _mm256_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm256_and_si256(_mm256_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm256_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm256_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm256_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm256_blendv_epi8(nan, t_value, mask); + t_value = + _mm256_packus_epi32(t_value, t_value); // t[4-7] t[4-7] t[0-4] t[0-4] + t_value = _mm256_permute4x64_epi64(t_value, 0xd8); // 11 01 10 00 + return _mm256_castsi256_si128(t_value); +} + +static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) { + __m256i lo = _mm256_castps_si256(a); + __m256i hi = _mm256_castps_si256(b); + __m256i nan = _mm256_set1_epi32(0xffff); + __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q)); + __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q)); + __m256i ones = _mm256_set1_epi32(0x1); + __m256i vec_bias = _mm256_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm256_and_si256(_mm256_srli_epi32(lo, 16), ones); + auto t_hi = _mm256_and_si256(_mm256_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm256_add_epi32(t_lo, vec_bias); + t_hi = _mm256_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm256_add_epi32(t_lo, lo); + t_hi = _mm256_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm256_srli_epi32(t_lo, 16); + t_hi = _mm256_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm256_blendv_epi8(nan, t_lo, mask_lo); + t_hi = _mm256_blendv_epi8(nan, t_hi, mask_hi); + + t_lo = _mm256_packus_epi32( + t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + return _mm256_permute4x64_epi64(t_lo, 0xd8); // 11 01 10 00 +} + +static inline __m256i merge_compare_result(const __m256& a, const __m256& b) { + __m256i lo = _mm256_castps_si256(a); + __m256i hi = _mm256_castps_si256(b); + lo = _mm256_srli_epi32(lo, 16); + hi = _mm256_srli_epi32(hi, 16); + auto out = _mm256_packus_epi32(lo, hi); + return _mm256_permute4x64_epi64(out, 0xd8); +} + +// float16 conversion +static inline void cvtfp16_fp32(const __m128i& a, __m256& o) { + o = _mm256_cvtph_ps(a); +} + +static inline void cvtfp16_fp32(const __m256i& a, __m256& o1, __m256& o2) { + __m128i lo = _mm256_extractf128_si256(a, 0); + __m128i hi = _mm256_extractf128_si256(a, 1); + cvtfp16_fp32(lo, o1); + cvtfp16_fp32(hi, o2); +} + +static inline __m128i cvtfp32_fp16(const __m256& src) { + return _mm256_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +static inline __m256i cvtfp32_fp16(const __m256& a, const __m256& b) { + __m128i lo = + _mm256_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i hi = + _mm256_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1); +} + +// dtype conversion between float16/bfloat16 and float32 +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m128i& a, __m256& o); +template <> +inline void cvt_to_fp32(const __m128i& a, __m256& o) { + cvtbf16_fp32(a, o); +} +template <> +inline void cvt_to_fp32(const __m128i& a, __m256& o) { + cvtfp16_fp32(a, o); +} + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2); +template <> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2) { + cvtbf16_fp32(a, o1, o2); +} +template <> +inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2) { + cvtfp16_fp32(a, o1, o2); +} + +template < + typename T, + bool is_compare_op = false, + typename std::enable_if_t, int> = 0> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b); +template <> +inline __m256i cvt_from_fp32( + const __m256& a, + const __m256& b) { + return cvtfp32_bf16(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return merge_compare_result(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return cvtfp32_fp16(a, b); +} +template <> +inline __m256i cvt_from_fp32(const __m256& a, const __m256& b) { + return cvtfp32_fp16(a, b); +} + +template +class Vectorized16 { + static_assert( + is_reduced_floating_point_v, + "Support only float16 and bfloat16."); + + protected: + __m256i values; + + public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized16() {} + Vectorized16(__m256i v) : values(v) {} + Vectorized16(T val) { + value_type uw = val.x; + values = _mm256_set1_epi16(uw); + } + Vectorized16( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16) { + values = _mm256_setr_epi16( + val1.x, + val2.x, + val3.x, + val4.x, + val5.x, + val6.x, + val7.x, + val8.x, + val9.x, + val10.x, + val11.x, + val12.x, + val13.x, + val14.x, + val15.x, + val16.x); + } + operator __m256i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256i cmp = _mm256_cmpeq_epi16(values, _mm256_set1_epi16(0)); + return _mm256_movemask_epi8(cmp); + } + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) + return _mm256_loadu_si256(reinterpret_cast(ptr)); + + __at_align__ int16_t tmp_values[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (const auto i : c10::irange(count, size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return _mm256_loadu_si256(reinterpret_cast(tmp_values)); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi16(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi16(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi16(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi16(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi16(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi16(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi16(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi16(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi16(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi16(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi16(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi16(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi16(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi16(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi16(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi16(b.values, 15); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + T base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + + // 'const' type qualifier on return type has no effect, but sleef defines this + // this way For example `Sleef_exp2f8_u10` signature is `const __m256 + // (__m256)` + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wignored-qualifiers") + Vectorized map(SLEEF_CONST __m256 (*SLEEF_CONST_OLD vop)(__m256)) const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvt_from_fp32(o1, o2); + } + C10_DIAGNOSTIC_POP() + Vectorized isnan() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + lo = _mm256_cmp_ps(lo, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + hi = _mm256_cmp_ps(hi, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + return merge_compare_result(lo, hi); + } + Vectorized abs() const { + return _mm256_andnot_si256(_mm256_set1_epi16(0x8000), values); + } + Vectorized angle() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto angle_lambda = [](__m256 values_2) { + const auto zero_vec = _mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values_2, values_2, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(c10::pi); + + const auto neg_mask = _mm256_cmp_ps(values_2, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf8_u10); + } + Vectorized acosh() const { + return map(Sleef_acoshf8_u10); + } + Vectorized asin() const { + return map(Sleef_asinf8_u10); + } + Vectorized atan() const { + return map(Sleef_atanf8_u10); + } + Vectorized atanh() const { + return map(Sleef_atanhf8_u10); + } + Vectorized atan2(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f8_u10(lo, b1); + auto o2 = Sleef_atan2f8_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized copysign(const Vectorized& sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m256i mask_value = _mm256_set1_epi32(~0x80008000); + __m256i mask_signbit = _mm256_set1_epi32(0x80008000); + return Vectorized(_mm256_or_si256( + _mm256_and_si256(values, mask_value), + _mm256_and_si256(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff8_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf8_u15); + } + Vectorized erfinv() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf8_u10); + } + Vectorized exp2() const { + return map(Sleef_exp2f8_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f8_u10); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + __m256 x_lo, x_hi; + cvt_to_fp32(values, x_lo, x_hi); + __m256 q_lo, q_hi; + cvt_to_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf8(x_lo, q_lo); + auto o2 = Sleef_fmodf8(x_hi, q_hi); + return cvt_from_fp32(o1, o2); + } + Vectorized hypot(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf8_u05(lo, b1); + auto o2 = Sleef_hypotf8_u05(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0e() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm256_loadu_ps(tmp1); + const auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized digamma() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_digamma(tmp1[i]); + tmp2[i] = calc_digamma(tmp2[i]); + } + const auto o1 = _mm256_loadu_ps(tmp1); + const auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized igamma(const Vectorized& x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + + Vectorized igammac(const Vectorized& x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf8_u10); + } + Vectorized log2() const { + return map(Sleef_log2f8_u10); + } + Vectorized log10() const { + return map(Sleef_log10f8_u10); + } + Vectorized log1p() const { + return map(Sleef_log1pf8_u10); + } + Vectorized sin() const { + return map(Sleef_sinf8_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf8_u10); + } + Vectorized cos() const { + return map(Sleef_cosf8_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf8_u10); + } + Vectorized ceil() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_ceil_ps(lo); + auto o2 = _mm256_ceil_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized floor() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_floor_ps(lo); + auto o2 = _mm256_floor_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized neg() const { + return _mm256_xor_si256(values, _mm256_set1_epi16(0x8000)); + } + Vectorized round() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = + _mm256_round_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = + _mm256_round_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf8_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf8_u10); + } + Vectorized trunc() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf8_u10); + } + Vectorized sqrt() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm256_sqrt_ps(lo); + auto o2 = _mm256_sqrt_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized reciprocal() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm256_set1_ps(1); + auto o1 = _mm256_div_ps(ones, lo); + auto o2 = _mm256_div_ps(ones, hi); + return cvt_from_fp32(o1, o2); + } + Vectorized rsqrt() const { + __m256 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm256_set1_ps(1); + auto o1 = _mm256_div_ps(ones, _mm256_sqrt_ps(lo)); + auto o2 = _mm256_div_ps(ones, _mm256_sqrt_ps(hi)); + return cvt_from_fp32(o1, o2); + } + Vectorized pow(const Vectorized& b) const { + __m256 lo, hi; + __m256 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_powf8_u10(lo, b1); + auto o2 = Sleef_powf8_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + + private: + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvt_to_fp32(values, a_lo, a_hi); + cvt_to_fp32(b.values, b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); + } + + public: + Vectorized inline operator>(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_GT_OQ); + }); + } + Vectorized inline operator<(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_LT_OQ); + }); + } + Vectorized inline operator>=(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_GE_OQ); + }); + } + Vectorized inline operator<=(const Vectorized& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_LE_OQ); + }); + } + Vectorized inline operator==(const Vectorized16& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_EQ_OQ); + }); + } + Vectorized inline operator!=(const Vectorized16& other) const { + return binary_compare(other, [](__m256 x, __m256 y) { + return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ); + }); + } +}; + +template +static inline Vectorized binary_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvt_to_fp32(__m256i(a), a_lo, a_hi); + cvt_to_fp32(__m256i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); +} + +#define CONVERT_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + __m256 o1, o2; \ + cvt_to_fp32(__m256i(a), o1, o2); \ + return std::make_tuple(o1, o2); \ + } \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + return cvt_from_fp32(__m256(a), __m256(b)); \ + } + +#define LOAD_FP32_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + auto values = _mm_loadu_si128(reinterpret_cast(data)); \ + __m256 out_values; \ + cvt_to_fp32(values, out_values); \ + out = out_values; \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + auto vec = Vectorized::loadu(data); \ + __m256 out1_values, out2_values; \ + cvt_to_fp32(vec, out1_values, out2_values); \ + out1 = out1_values; \ + out2 = out2_values; \ + } + +#else // CPU_CAPABILITY_AVX2 + +#define CONVERT_NON_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr2); \ + convert(arr2, arr, K); \ + return std::make_tuple( \ + Vectorized::loadu(arr), \ + Vectorized::loadu(arr + Vectorized::size())); \ + } \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr); \ + b.store(arr + Vectorized::size()); \ + convert(arr, arr2, K); \ + return Vectorized::loadu(arr2); \ + } + +#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + __at_align__ float values[Vectorized::size()]; \ + for (const auto k : c10::irange(Vectorized::size())) { \ + values[k] = data[k]; \ + } \ + out = Vectorized::loadu(values); \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + load_fp32_from_##name(data, out1); \ + data += Vectorized::size(); \ + load_fp32_from_##name(data, out2); \ + } + +#endif // CPU_CAPABILITY_AVX2 +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..6fec6b9b7b59a2ba50b720c71b4146992b665084 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -0,0 +1,285 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = BFloat16; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_si256(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_si256(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_si256(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + auto max_lo = _mm256_max_ps(a_lo, b_lo); + auto max_hi = _mm256_max_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(max_lo, nan_lo); + auto o2 = _mm256_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + auto min_lo = _mm256_min_ps(a_lo, b_lo); + auto min_hi = _mm256_min_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(min_lo, nan_lo); + auto o2 = _mm256_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + __m256 max_lo, max_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(min), min_lo, min_hi); + cvtbf16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo)); + auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 max_lo, max_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, a_lo); + auto o2 = _mm256_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(min), min_lo, min_hi); + auto o1 = _mm256_max_ps(min_lo, a_lo); + auto o2 = _mm256_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, BFloat16* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = _mm256_loadu_ps(&src[i]); + __m256 b = _mm256_loadu_ps(&src[i + 8]); + + __m256i bf = cvtfp32_bf16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, BFloat16* dst, int64_t n) { + auto load_float = [](const double* src) -> __m256 { + // Load one float vector from an array of doubles + __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src)); + __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4)); + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = load_float(&src[i]); + __m256 b = load_float(&src[i + 8]); + + __m256i bf = cvtfp32_bf16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + __m256 c_lo, c_hi; + cvtbf16_fp32(__m256i(a), a_lo, a_hi); + cvtbf16_fp32(__m256i(b), b_lo, b_hi); + cvtbf16_fp32(__m256i(c), c_lo, c_hi); + auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +CONVERT_VECTORIZED_INIT(BFloat16, bfloat16) +LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16) + +#else // defined(CPU_CAPABILITY_AVX2) + +#if !( \ + defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE256)) +CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) +#endif + +LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16) +#endif // defined(CPU_CAPABILITY_AVX2) +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h new file mode 100644 index 0000000000000000000000000000000000000000..a8b68fdfc60003e8bf42dcaec98fdc02219bda15 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h @@ -0,0 +1,543 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m256d values; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 2; + } + Vectorized() { + values = _mm256_setzero_pd(); + } + Vectorized(__m256d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm256_setr_pd(real_value, imag_value, real_value, imag_value); + } + Vectorized(c10::complex val1, c10::complex val2) { + values = _mm256_setr_pd(val1.real(), val1.imag(), val2.real(), val2.imag()); + } + operator __m256d() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 4, "Unexpected mask value"); + switch (mask) { + case 0: + return a; + case 1: + return _mm256_blend_pd(a.values, b.values, 0x03); + case 2: + return _mm256_blend_pd(a.values, b.values, 0x0c); + case 3: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values); + return _mm256_blendv_pd(a.values, b.values, mask_); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>(base, base + step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm256_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm256_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2 * size()]; + _mm256_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256d abs_2_() const { + auto val_2 = _mm256_mul_pd(values, values); // a*a b*b + return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m256d abs_() const { + auto real = _mm256_movedup_pd(values); // real real + // movehdup_pd does not exist... + auto imag = _mm256_permute_pd(values, 0xf); // imag imag + return Sleef_hypotd4_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm256_and_pd(abs_(), real_mask); // abs 0 + } + __m256d angle_() const { + // angle = atan2(b/a) + auto b_a = _mm256_permute_pd(values, 0x05); // b a + return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle + return _mm256_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm256_setzero_pd(); + auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ); + auto div = _mm256_div_pd(values, abs); + return _mm256_blendv_pd(div, zero, mask); + } + __m256d real_() const { + const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm256_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m256d imag_() const { + const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x( + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF)); + return _mm256_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm256_permute_pd(imag_(), 0x05); // b a + } + __m256d conj_() const { + const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + return _mm256_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m256d log2_ = _mm256_set1_pd(std::log(2)); + return _mm256_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m256d log10_ = _mm256_set1_pd(std::log(10)); + return _mm256_div_pd(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m256d one = _mm256_set1_pd(1); + + // auto conj = conj_(); + // auto b_a = _mm256_permute_pd(conj, 0x05); //-b a + // auto ab = _mm256_mul_pd(conj, b_a); //-ab + // -ab auto im = _mm256_add_pd(ab, ab); //-2ab -2ab + + // auto val_2 = _mm256_mul_pd(values, values); // a*a + // b*b auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05)); // + // a*a-b*b b*b-a*a re = _mm256_sub_pd(one, re); + + // auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt(); //sqrt(re + + // i*im) auto ln = Vectorized(_mm256_add_pd(b_a, root)).log(); //ln(iz + + // sqrt()) return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj(); + // //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0); + return _mm256_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expd4_u10(values); //exp(a) exp(b) exp = + // _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosd4_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, + // 0x05), + // sin_cos.x, 0x0A); //cos(b) sin(b) + // return _mm256_mul_pd(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m256d ln_2 = _mm256_set1_pd(c10::ln_2); + Vectorized> scaled_values = + _mm256_mul_pd(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm256_ceil_pd(values); + } + Vectorized> floor() const { + return _mm256_floor_pd(values); + } + Vectorized> neg() const { + auto zero = _mm256_setzero_pd(); + return _mm256_sub_pd(zero, values); + } + Vectorized> round() const { + return _mm256_round_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); + } + Vectorized> operator!=( + const Vectorized>& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); + } + Vectorized> operator<( + const Vectorized>& /*unused*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& /*unused*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& /*unused*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& /*unused*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_add_pd(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_sub_pd(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm256_mul_pd(a, b); // ac bd + + auto d_c = _mm256_permute_pd(b, 0x05); // d c + d_c = _mm256_xor_pd(sign_mask, d_c); // d -c + auto ad_bc = _mm256_mul_pd(a, d_c); // ad -bc + + auto ret = _mm256_hsub_pd(ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm256_set1_pd(-0.f); + // auto fabs_cd = _mm256_andnot_pd(mask, b); // |c| |d| + // auto fabs_dc = _mm256_permute_pd(fabs_cd, 0x05); // |d| |c| + // auto scale = _mm256_div_pd(_mm256_set1_pd(1.0f), _mm256_max_pd(fabs_cd, + // fabs_dc)); // 1/sc 1/sc auto a2 = _mm256_mul_pd(a, scale); // + // a/sc b/sc auto b2 = _mm256_mul_pd(b, scale); // c/sc d/sc + // auto acbd2 = _mm256_mul_pd(a2, b2); + + // const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0); + // auto dc2 = _mm256_permute_pd(b2, 0x05); // d/sc c/sc + // dc2 = _mm256_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm256_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = _mm256_hadd_pd(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm256_div_pd(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm256_loadu_pd(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); + // auto c_d = _mm256_xor_pd(sign_mask, values); //c -d + // return _mm256_div_pd(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5); + + // auto sum = Vectorized(_mm256_add_pd(i, values)); // a + // 1+b auto sub = Vectorized(_mm256_sub_pd(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(max, isnan); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_pd(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_pd(min, isnan); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_and_pd(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_or_pd(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_xor_pd(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm256_set1_pd(1.0)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm256_set1_pd(1.0)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h new file mode 100644 index 0000000000000000000000000000000000000000..96d0530f038d32d5eebfd82269c1df7cd5ae5daa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h @@ -0,0 +1,625 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m256 values; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() { + values = _mm256_setzero_ps(); + } + Vectorized(__m256 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm256_setr_ps( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4) { + values = _mm256_setr_ps( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag()); + } + operator __m256() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 16, "Unexpected mask range"); + switch (mask) { + case 0: + return a; + case 1: + return _mm256_blend_ps( + a.values, b.values, 0x03); // b0000 0001 = b0000 0011 + case 2: + return _mm256_blend_ps( + a.values, b.values, 0x0C); // b0000 0010 = b0000 1100 + case 3: + return _mm256_blend_ps( + a.values, b.values, 0x0F); // b0000 0011 = b0000 1111 + case 4: + return _mm256_blend_ps( + a.values, b.values, 0x30); // b0000 0100 = b0011 0000 + case 5: + return _mm256_blend_ps( + a.values, b.values, 0x33); // b0000 0101 = b0011 0011 + case 6: + return _mm256_blend_ps( + a.values, b.values, 0x3C); // b0000 0110 = b0011 1100 + case 7: + return _mm256_blend_ps( + a.values, b.values, 0x3F); // b0000 0111 = b0011 1111 + case 8: + return _mm256_blend_ps( + a.values, b.values, 0xC0); // b0000 1000 = b1100 0000 + case 9: + return _mm256_blend_ps( + a.values, b.values, 0xC3); // b0000 1001 = b1100 0011 + case 10: + return _mm256_blend_ps( + a.values, b.values, 0xCC); // b0000 1010 = b1100 1100 + case 11: + return _mm256_blend_ps( + a.values, b.values, 0xCF); // b0000 1011 = b1100 1111 + case 12: + return _mm256_blend_ps( + a.values, b.values, 0xF0); // b0000 1100 = b1111 0000 + case 13: + return _mm256_blend_ps( + a.values, b.values, 0xF3); // b0000 1101 = b1111 0011 + case 14: + return _mm256_blend_ps( + a.values, b.values, 0xFC); // b0000 1110 = b1111 1100 + default: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values); + return _mm256_blendv_ps(a.values, b.values, mask_); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + step, + base + c10::complex(2) * step, + base + c10::complex(3) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm256_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm256_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2 * size()]; + _mm256_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m256 abs_2_() const { + auto val_2 = _mm256_mul_ps(values, values); // a*a b*b + auto ret = _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return _mm256_permute_ps(ret, 0xD8); + } + __m256 abs_() const { + auto real = _mm256_moveldup_ps(values); // real real + auto imag = _mm256_movehdup_ps(values); // imag imag + return Sleef_hypotf8_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm256_and_ps(abs_(), real_mask); // abs 0 + } + __m256 angle_() const { + // angle = atan2(b/a) + auto b_a = _mm256_permute_ps(values, 0xB1); // b a + return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm256_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm256_setzero_ps(); + auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ); + auto div = _mm256_div_ps(values, abs); + return _mm256_blendv_ps(div, zero, mask); + } + __m256 real_() const { + const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm256_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m256 imag_() const { + const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32( + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF)); + return _mm256_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm256_permute_ps(imag_(), 0xB1); // b a + } + __m256 conj_() const { + const __m256 sign_mask = + _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm256_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m256 log2_ = _mm256_set1_ps(std::log(2)); + return _mm256_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m256 log10_ = _mm256_set1_ps(std::log(10)); + return _mm256_div_ps(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m256 one = _mm256_set1_ps(1); + + // auto conj = conj_(); + // auto b_a = _mm256_permute_ps(conj, 0xB1); //-b a + // auto ab = _mm256_mul_ps(conj, b_a); //-ab + // -ab auto im = _mm256_add_ps(ab, ab); //-2ab -2ab + + // auto val_2 = _mm256_mul_ps(values, values); // a*a + // b*b auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1)); // + // a*a-b*b b*b-a*a re = _mm256_permute_ps(re, 0xD8); re = + // _mm256_sub_ps(one, re); + + // auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt(); //sqrt(re + + // i*im) auto ln = Vectorized(_mm256_add_ps(b_a, root)).log(); //ln(iz + + // sqrt()) return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj(); + // //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expf8_u10(values); //exp(a) exp(b) exp = + // _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosf8_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, + // 0xB1), + // sin_cos.x, 0xAA); //cos(b) sin(b) + // return _mm256_mul_ps(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m256 ln_2 = _mm256_set1_ps(c10::ln_2); + Vectorized> scaled_values = _mm256_mul_ps(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm256_ceil_ps(values); + } + Vectorized> floor() const { + return _mm256_floor_ps(values); + } + Vectorized> neg() const { + auto zero = _mm256_setzero_ps(); + return _mm256_sub_ps(zero, values); + } + Vectorized> round() const { + return _mm256_round_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); + } + Vectorized> operator!=( + const Vectorized>& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); + } + Vectorized> operator<( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& /*other*/) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_add_ps(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_sub_ps(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m256 sign_mask = + _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm256_mul_ps(a, b); // ac bd + + auto d_c = _mm256_permute_ps(b, 0xB1); // d c + d_c = _mm256_xor_ps(sign_mask, d_c); // d -c + auto ad_bc = _mm256_mul_ps(a, d_c); // ad -bc + + auto ret = _mm256_hsub_ps(ac_bd, ad_bc); // ac - bd ad + bc + ret = _mm256_permute_ps(ret, 0xD8); + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm256_set1_ps(-0.f); + // auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d| + // auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c| + // auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc + // auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc + // auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc + // auto acbd2 = _mm256_mul_ps(a2, b2); + + // const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc + // dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 + // res2 = _mm256_permute_ps(res2, 0xD8); + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm256_div_ps(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm256_loadu_ps(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); auto c_d = _mm256_xor_ps(sign_mask, values); //c -d + // return _mm256_div_ps(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm256_add_ps(i, values)); // a + // 1+b auto sub = Vectorized(_mm256_sub_ps(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm256_blendv_ps(a, b, mask); + // Exploit the fact that all-ones is a NaN. + auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_and_ps(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_or_ps(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm256_xor_ps(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm256_set1_ps(1.0f)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm256_set1_ps(1.0f)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..4ea85701b7cbbef81f26709ea08be38cdea3e108 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_convert.h @@ -0,0 +1,370 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + __m256 value; + cvtbf16_fp32(_mm256_castsi256_si128(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + __m256 value; + cvtfp16_fp32(_mm256_castsi256_si128(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = _mm256_castsi128_si256(cvtfp32_bf16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = _mm256_castsi128_si256(cvtfp32_fp16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + +template <> +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src); + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low_double = at::vec::convert_to_fp_of_same_size(src[0]); + auto low = _mm256_cvtpd_ps(low_double); + auto high_double = at::vec::convert_to_fp_of_same_size(src[1]); + auto high = _mm256_cvtpd_ps(high_double); + return Vectorized( + _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + // Scalarization is the most reliable way of converting fp to int64 on AVX2. + // Check: https://stackoverflow.com/questions/41144668 + float buffer[8]; + src.store(buffer); + at::vec::VectorizedN result; + result[0] = Vectorized( + static_cast(buffer[0]), + static_cast(buffer[1]), + static_cast(buffer[2]), + static_cast(buffer[3])); + result[1] = Vectorized( + static_cast(buffer[4]), + static_cast(buffer[5]), + static_cast(buffer[6]), + static_cast(buffer[7])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0)); + auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0)); + auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0)); + auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0)); + return Vectorized(_mm256_blend_epi32(low_perm, high_perm, 0xF0)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src[0])); + result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src[0], 1)); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepi8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepu8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm256_cvttps_epi32(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm256_cvtepi32_ps(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm256_castsi256_si128(src[0]); + return Vectorized(_mm256_cvtepu8_epi16(src128)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + src_t, + 1, + typename std::enable_if_t< + (is_reduced_floating_point_v && is_8bit_integer_v) || + (is_reduced_floating_point_v && is_8bit_integer_v), + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 2, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm256_castps256_ps128(_mm256_castsi256_ps(vec2)); + __m256 combined = _mm256_insertf128_ps(_mm256_castsi256_ps(vec1), lane2, 1); + // Shuffle [191:128] bit from combined in to [127:64] bit of result + __m256i result = + _mm256_permute4x64_epi64(_mm256_castps_si256(combined), 0b11011000); + return at::vec::Vectorized(result); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_float_to_int8(src[0]); + } +}; + +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + // Shuffle [127:64] bit from src[0] in to [191:128] bit of shuffled + __m256i shuffled = _mm256_permute4x64_epi64(src[0], 0b11011000); + __m256i src2 = + _mm256_castsi128_si256(_mm_castps_si128(_mm256_extractf128_ps( + _mm256_castsi256_ps(shuffled), 1) // Extract the second 128-bit lane + )); + return VectorizedN( + convert_int8_to_float(src[0]), + convert_int8_to_float(src2)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + int64_t, + 2, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const VectorizedN& src) { + return VecConvert::apply( + VecConvert::apply(src)); + } +}; + +#endif /* defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) */ + +#if (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)) +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_to_float(src[0]); + } +}; +#endif + +#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + // Load 16-bit unsigned integers from src into an SVE vector + svuint16_t u16x4 = + svld1_u16(svptrue_b16(), reinterpret_cast(&src[0])); + // Zero-extend to 32-bit SVE does not have direct vmovl_u16 equivalent. + vls_uint32_t u32x4 = + svreinterpret_u32_u16(svzip1_u16(svdup_n_u16(0), u16x4)); + // Reinterpret as float32 + vls_float32_t f32x4 = svreinterpret_f32_u32(u32x4); + res[0] = Vectorized(f32x4); + return res; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + std::tie(res[0], res[1]) = convert_bfloat16_float(src[0]); + return res; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN res; + res[0] = convert_float_bfloat16(src[0], src[1]); + return res; + } +}; + +#endif // defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16) + +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + auto [res_vec1, res_vec2] = convert_to_float(src[0]); + return res_vec1; + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_from_float(src[0], src[0]); + } +}; + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h new file mode 100644 index 0000000000000000000000000000000000000000..34c34f62526d9cb2d5cd5ed9d8e396280ca608f8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h @@ -0,0 +1,531 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + __m256d values; + + public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() { + values = _mm256_setzero_pd(); + } + Vectorized(__m256d v) : values(v) {} + Vectorized(double val) { + values = _mm256_set1_pd(val); + } + Vectorized(double val1, double val2, double val3, double val4) { + values = _mm256_setr_pd(val1, val2, val3, val4); + } + operator __m256d() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm256_blend_pd(a.values, b.values, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_pd(a.values, b.values, mask.values); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(double)); + return _mm256_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm256_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[size()]; + _mm256_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(double)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ); + return _mm256_movemask_pd(cmp); + } + Vectorized isnan() const { + return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q); + } + bool has_inf_nan() const { + __m256d self_sub = _mm256_sub_pd(values, values); + return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != + 0; + } + Vectorized map(double (*const f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm256_set1_pd(-0.f); + return _mm256_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm256_set1_pd(0.f); + const auto nan_vec = _mm256_set1_pd(NAN); + const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_pd(c10::pi); + + const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask); + angle = _mm256_blendv_pd(angle, nan_vec, nan_mask); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd4_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshd4_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind4_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhd4_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand4_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhd4_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2d4_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignd4(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd4_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd4_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expd4_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2d4_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d4_u10(values)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd4(values, q)); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotd4_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd4_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d4_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d4_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd4_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind4_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhd4_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd4_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd4_u10(values)); + } + Vectorized ceil() const { + return _mm256_ceil_pd(values); + } + Vectorized floor() const { + return _mm256_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + return _mm256_xor_pd(_mm256_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterd4(values, b)); + } + Vectorized round() const { + return _mm256_round_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand4_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd4_u10(values)); + } + Vectorized trunc() const { + return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad4_u10(values)); + } + Vectorized sqrt() const { + return _mm256_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm256_div_pd(_mm256_set1_pd(1), values); + } + Vectorized rsqrt() const { + return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powd4_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); + } + + Vectorized operator!=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); + } + + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ); + } + + Vectorized operator<=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ); + } + + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ); + } + + Vectorized operator>=(const Vectorized& other) const { + return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_pd(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_pd(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mul_pd(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm256_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + Vectorized max = _mm256_max_pd(a, b); + Vectorized isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + Vectorized min = _mm256_min_pd(a, b); + Vectorized isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm256_min_pd(max, _mm256_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm256_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm256_min_pd(max, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_pd(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_pd(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_pd(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +#ifdef CPU_CAPABILITY_AVX2 +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmadd_pd(a, b, c); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fnmadd_pd(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmsub_pd(a, b, c); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fnmsub_pd(a, b, c); +} +#endif + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h new file mode 100644 index 0000000000000000000000000000000000000000..1a2cbb07006467f5eded6893f5aadf4d68e93053 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h @@ -0,0 +1,847 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX2) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + __m256 values; + + public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() { + values = _mm256_setzero_ps(); + } + Vectorized(__m256 v) : values(v) {} + Vectorized(float val) { + values = _mm256_set1_ps(val); + } + Vectorized( + float val1, + float val2, + float val3, + float val4, + float val5, + float val6, + float val7, + float val8) { + values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8); + } + Vectorized(const float (&arr)[8]) + : Vectorized( + arr[0], + arr[1], + arr[2], + arr[3], + arr[4], + arr[5], + arr[6], + arr[7]) {} + operator __m256() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm256_blend_ps(a.values, b.values, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_ps(a.values, b.values, mask.values); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm256_loadu_ps(reinterpret_cast(ptr)); + __at_align__ float tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, reinterpret_cast(ptr), count * sizeof(float)); + return _mm256_loadu_ps(tmp_values); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm256_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[size()]; + _mm256_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + return _mm256_movemask_ps(cmp); + } + Vectorized isnan() const { + return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); + } + + bool has_inf_nan() const { + __m256 self_sub = _mm256_sub_ps(values, values); + return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != + 0; + } + + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm256_set1_ps(-0.f); + return _mm256_andnot_ps(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm256_set1_ps(0.f); + const auto nan_vec = _mm256_set1_ps(NAN); + const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ); + const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); + const auto pi = _mm256_set1_ps(c10::pi); + + const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); + angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf8_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshf8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf8_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhf8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf8_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhf8_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2f8_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignf8(values, sign)); + } + Vectorized erf() const { + // constants + const auto neg_zero_vec = _mm256_set1_ps(-0.f); + const auto one_vec = _mm256_set1_ps(1.0f); + const auto p = _mm256_set1_ps(0.3275911f); + const auto p1 = _mm256_set1_ps(0.254829592f); + const auto p2 = _mm256_set1_ps(-0.284496736f); + const auto p3 = _mm256_set1_ps(1.421413741f); + const auto p4 = _mm256_set1_ps(-1.453152027f); + const auto p5 = _mm256_set1_ps(1.061405429f); + // sign(x) + auto sign_mask = _mm256_and_ps(neg_zero_vec, values); + auto abs_vec = _mm256_xor_ps(sign_mask, values); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = _mm256_fmadd_ps(p, abs_vec, one_vec); + auto t = _mm256_div_ps(one_vec, tmp0); + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = _mm256_fmadd_ps(p5, t, p4); + auto tmp2 = _mm256_fmadd_ps(tmp1, t, p3); + auto tmp3 = _mm256_fmadd_ps(tmp2, t, p2); + auto r = _mm256_fmadd_ps(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = _mm256_mul_ps(values, values); + auto neg_pow_2 = _mm256_xor_ps(neg_zero_vec, pow_2); + // auto tmp4 = exp(neg_pow_2); + auto tmp4 = Vectorized(Sleef_expf8_u10(neg_pow_2)); + auto tmp5 = _mm256_xor_ps(neg_zero_vec, tmp4); + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = _mm256_mul_ps(tmp5, t); + auto tmp7 = _mm256_fmadd_ps(tmp6, r, one_vec); + return _mm256_xor_ps(sign_mask, tmp7); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf8_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2f8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f8_u10(values)); + } + Vectorized fexp_u20() const { + const __m256 vec_c0 = _mm256_set1_ps(0.00010703434948458272f); + const __m256 vec_c1 = _mm256_set1_ps(0.30354260500649682f); + const __m256 vec_c2 = _mm256_set1_ps(-0.22433836478672356); + const __m256 vec_c3 = _mm256_set1_ps(-0.079204240219773236); + + const __m256 vec_exp_log2ef = + _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e) + + const __m256 vec_a = _mm256_set1_ps(std::pow(2, 23) / std::log2(2)); + const __m256 vec_b = _mm256_set1_ps(std::pow(2, 23) * 127.f); + + const __m256 vec_ln_flt_min = + _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); + const __m256 vec_ln_flt_max = + _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); + const __m256 vec_inf = _mm256_set1_ps(INFINITY); + const __m256 zero = _mm256_setzero_ps(); + + // exp(x) = 2**(x * log2(e)) + // = 2**xi * 2**xf - TIPS we are using the EEEE floating point + // representation with identification to the exponent and the + // mentissa + // 2**xf will be approximated to a polynomial of degree 3 computed with + // Horner method + // compute the min/max for the mask + // Masks + __m256 mask_too_small = + _mm256_cmp_ps(values, vec_ln_flt_min, _CMP_LT_OS); // x < min + __m256 mask_too_large = + _mm256_cmp_ps(values, vec_ln_flt_max, _CMP_GT_OS); // x > max + + // transformation with log2(e) + auto vec_src = _mm256_mul_ps(values, vec_exp_log2ef); + auto vec_fractional = _mm256_sub_ps(vec_src, _mm256_floor_ps(vec_src)); + + // compute polynomial using Horner Scheme + auto vec_res = _mm256_fmadd_ps(vec_fractional, vec_c3, vec_c2); + vec_res = _mm256_fmadd_ps(vec_fractional, vec_res, vec_c1); + vec_res = _mm256_fmadd_ps(vec_fractional, vec_res, vec_c0); + + vec_src = _mm256_sub_ps(vec_src, vec_res); + // // the tips is here, headache in perspective + auto tmp = _mm256_fmadd_ps(vec_a, vec_src, vec_b); + // headache bis + __m256i casted_integer = _mm256_cvttps_epi32(tmp); + // bitwise to float for the final transformation + auto result = _mm256_castsi256_ps(casted_integer); + // boundary condition + // Set to 0 where x < ln(FLT_MIN) + result = _mm256_blendv_ps(result, zero, mask_too_small); + // Set to +inf where x > ln(FLT_MAX) + result = _mm256_blendv_ps(result, vec_inf, mask_too_large); + // final interpretation to float + return result; + } + + Vectorized exp_u20() const { + // A faster version of exp with ULP=20 + const __m256 vec_factorial_1 = + _mm256_set1_ps(0.999999701f); // 1/factorial(1) + const __m256 vec_factorial_2 = + _mm256_set1_ps(0.499991506f); // 1/factorial(2) + const __m256 vec_factorial_3 = + _mm256_set1_ps(0.166676521f); // 1/factorial(3) + const __m256 vec_factorial_4 = + _mm256_set1_ps(0.0418978221f); // 1/factorial(4) + const __m256 vec_factorial_5 = + _mm256_set1_ps(0.00828929059f); // 1/factorial(5) + const __m256 vec_exp_log2ef = + _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e) + const __m256 vec_half = _mm256_set1_ps(0.5f); + const __m256 vec_one = _mm256_set1_ps(1.f); + const __m256 vec_zero = _mm256_set1_ps(0.f); + const __m256 vec_two = _mm256_set1_ps(2.f); + const __m256 vec_ln2f = + _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2) + const __m256 vec_ln_flt_min = + _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); + const __m256 vec_ln_flt_max = + _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); + const __m256i vec_127 = _mm256_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; + + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + auto less_ln_flt_min_mask = + _mm256_cmp_ps(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); + auto vec_src = _mm256_min_ps(values, vec_ln_flt_max); + vec_src = _mm256_max_ps(vec_src, vec_ln_flt_min); + + // fx = floorf(x * log2ef + 0.5) + auto vec_fx = _mm256_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); + vec_fx = _mm256_floor_ps(vec_fx); + + // x = x - fx * ln2 + auto vec_exp_poly = _mm256_fnmadd_ps(vec_fx, vec_ln2f, vec_src); + + // compute polynomial + auto vec_res = + _mm256_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); + vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_one); + + // compute 2^(n-1) + auto vec_exp_number = _mm256_sub_ps(vec_fx, vec_one); + auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number); + auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127); + vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); + auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i); + vec_two_pow_n = + _mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask); + + // y = y * 2^n + vec_res = _mm256_mul_ps(vec_res, vec_two_pow_n); + vec_res = _mm256_mul_ps(vec_res, vec_two); + return vec_res; + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf8(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f8_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pf8_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf8_u35(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhf8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosf8_u35(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf8_u10(values)); + } + Vectorized ceil() const { + return _mm256_ceil_ps(values); + } + Vectorized floor() const { + return _mm256_floor_ps(values); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotf8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm256_xor_ps(_mm256_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterf8(values, b)); + } + Vectorized round() const { + return _mm256_round_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf8_u10(values)); + } + Vectorized trunc() const { + return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammaf8_u10(values)); + } + Vectorized sqrt() const { + return _mm256_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm256_div_ps(_mm256_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powf8_u10(values, b)); + } + float reduce_add() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = _mm256_add_ps(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = _mm256_add_ps(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = _mm256_add_ps(v, v1); + return _mm256_cvtss_f32(v); + } + float reduce_max() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_ps(v, v, 0x1); + v = _mm256_max_ps(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0x4E); + v = _mm256_max_ps(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_ps(v, v, 0xB1); + v = _mm256_max_ps(v, v1); + return _mm256_cvtss_f32(v); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); + } + + Vectorized operator!=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); + } + + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ); + } + + Vectorized operator<=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ); + } + + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ); + } + + Vectorized operator>=(const Vectorized& other) const { + return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_ps(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_ps(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mul_ps(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm256_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + Vectorized max = _mm256_max_ps(a, b); + Vectorized isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + Vectorized min = _mm256_min_ps(a, b); + Vectorized isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm256_min_ps(max, _mm256_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm256_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm256_max_ps(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_ps(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_ps(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_ps(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmadd_ps(a, b, c); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fnmadd_ps(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fmsub_ps(a, b, c); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm256_fnmsub_ps(a, b, c); +} + +// TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen for micro gemm +inline void transpose_block(at::vec::VectorizedN& input) { + __m256 temp0[8]; + // unpacking and interleaving 32-bit elements + // a0 b0 a1 b1 a4 b4 a5 b5 + // a2 b2 a3 b3 a6 b6 a7 b7 + // c0 d0 c1 d1 ... + // c2 d2 c3 d3 ... + // e0 f0 e1 f1 ... + // e2 f2 e3 f3 ... + // g0 h0 g1 h1 ... + // g2 h2 g3 h3 ... + temp0[0] = _mm256_unpacklo_ps(input[0], input[1]); + temp0[1] = _mm256_unpackhi_ps(input[0], input[1]); + temp0[2] = _mm256_unpacklo_ps(input[2], input[3]); + temp0[3] = _mm256_unpackhi_ps(input[2], input[3]); + temp0[4] = _mm256_unpacklo_ps(input[4], input[5]); + temp0[5] = _mm256_unpackhi_ps(input[4], input[5]); + temp0[6] = _mm256_unpacklo_ps(input[6], input[7]); + temp0[7] = _mm256_unpackhi_ps(input[6], input[7]); + + __m256 temp1[8]; + // unpacking and interleaving 64-bit elements + // a0 b0 c0 d0 a4 b4 c4 d4 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // e0 f0 g0 h0 e4 f4 g4 h4 + // e1 f1 g1 h1 ... + // e2 f2 g2 h2 ... + // e3 f3 g3 h3 ... + temp1[0] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[0]), _mm256_castps_pd(temp0[2]))); + temp1[1] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[0]), _mm256_castps_pd(temp0[2]))); + temp1[2] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[1]), _mm256_castps_pd(temp0[3]))); + temp1[3] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[1]), _mm256_castps_pd(temp0[3]))); + temp1[4] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[4]), _mm256_castps_pd(temp0[6]))); + temp1[5] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[4]), _mm256_castps_pd(temp0[6]))); + temp1[6] = _mm256_castpd_ps(_mm256_unpacklo_pd( + _mm256_castps_pd(temp0[5]), _mm256_castps_pd(temp0[7]))); + temp1[7] = _mm256_castpd_ps(_mm256_unpackhi_pd( + _mm256_castps_pd(temp0[5]), _mm256_castps_pd(temp0[7]))); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 e0 f0 g0 h0 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // a4 b4 c4 d4 ... + // a5 b5 c5 d5 ... + // a6 b6 c6 d6 ... + // a7 b7 c7 d7 ... + input[0] = _mm256_permute2f128_ps(temp1[0], temp1[4], 0x20); + input[1] = _mm256_permute2f128_ps(temp1[1], temp1[5], 0x20); + input[2] = _mm256_permute2f128_ps(temp1[2], temp1[6], 0x20); + input[3] = _mm256_permute2f128_ps(temp1[3], temp1[7], 0x20); + input[4] = _mm256_permute2f128_ps(temp1[0], temp1[4], 0x31); + input[5] = _mm256_permute2f128_ps(temp1[1], temp1[5], 0x31); + input[6] = _mm256_permute2f128_ps(temp1[2], temp1[6], 0x31); + input[7] = _mm256_permute2f128_ps(temp1[3], temp1[7], 0x31); +} + +// Used by Inductor CPP codegen +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + // load from src to registers + at::vec::VectorizedN input; + // a: a0 a1 a2 a3 a4 a5 a6 a7 + // b: b0 b1 b2 b3 b4 b5 b6 b7 + // c: c0 c1 c2 c3 c4 c5 c6 c7 + // d: d0 d1 d2 d3 d4 d5 d6 d7 + // e: e0 e1 e2 e3 e4 e5 e6 e7 + // f: f0 f1 f2 f3 f4 f5 f6 f7 + // g: g0 g1 g2 g3 g4 g5 g6 g7 + // h: h0 h1 h2 h3 h4 h5 h6 h7 + int i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i < 8; i++) { + input[i] = _mm256_loadu_ps(&src[i * ld_src]); + } + + transpose_block(input); + + // store from registers to dst +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i < 8; i++) { + _mm256_storeu_ps(&dst[i * ld_dst], input[i]); + } +} + +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst); + transpose_mxn(src + 8, ld_src, dst + 8 * ld_dst, ld_dst); + transpose_mxn(src + 8 * ld_src, ld_src, dst + 8, ld_dst); + transpose_mxn( + src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst); +} +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h new file mode 100644 index 0000000000000000000000000000000000000000..e5d95b014801a22c7eec6b9295baa51a66f0fd2c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_half.h @@ -0,0 +1,285 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX2 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = Half; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { + return _mm256_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm256_and_si256(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm256_or_si256(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm256_xor_si256(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + auto max_lo = _mm256_max_ps(a_lo, b_lo); + auto max_hi = _mm256_max_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(max_lo, nan_lo); + auto o2 = _mm256_or_ps(max_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + auto min_lo = _mm256_min_ps(a_lo, b_lo); + auto min_hi = _mm256_min_ps(a_hi, b_hi); + auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm256_or_ps(min_lo, nan_lo); + auto o2 = _mm256_or_ps(min_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + __m256 max_lo, max_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(min), min_lo, min_hi); + cvtfp16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo)); + auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi)); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m256 a_lo, a_hi; + __m256 max_lo, max_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(max), max_lo, max_hi); + auto o1 = _mm256_min_ps(max_lo, a_lo); + auto o2 = _mm256_min_ps(max_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m256 a_lo, a_hi; + __m256 min_lo, min_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(min), min_lo, min_hi); + auto o1 = _mm256_max_ps(min_lo, a_lo); + auto o2 = _mm256_max_ps(min_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +inline void convert(const Half* src, Half* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, Half* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = _mm256_loadu_ps(&src[i]); + __m256 b = _mm256_loadu_ps(&src[i + 8]); + + __m256i c = cvtfp32_fp16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, Half* dst, int64_t n) { + auto load_float = [](const double* src) -> __m256 { + // Load one float vector from an array of doubles + __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src)); + __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4)); + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m256 a = load_float(&src[i]); + __m256 b = load_float(&src[i + 8]); + + __m256i c = cvtfp32_fp16(a, b); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m256 a_lo, a_hi; + __m256 b_lo, b_hi; + __m256 c_lo, c_hi; + cvtfp16_fp32(__m256i(a), a_lo, a_hi); + cvtfp16_fp32(__m256i(b), b_lo, b_hi); + cvtfp16_fp32(__m256i(c), c_lo, c_hi); + auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_fp16(o1, o2); +} + +CONVERT_VECTORIZED_INIT(Half, half) +LOAD_FP32_VECTORIZED_INIT(Half, fp16) + +#else // defined(CPU_CAPABILITY_AVX2) + +#if !( \ + defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE256)) +CONVERT_NON_VECTORIZED_INIT(Half, half) +#endif + +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16) +#endif // defined(CPU_CAPABILITY_AVX2) +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h new file mode 100644 index 0000000000000000000000000000000000000000..bb2866dfc45192365a6d31495ccfdfe9fe5c1a98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h @@ -0,0 +1,2327 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX2 + +struct Vectorizedi { + protected: + __m256i values; + + static inline __m256i invert(const __m256i& v) { + const auto ones = _mm256_set1_epi64x(-1); + return _mm256_xor_si256(ones, v); + } + + public: + Vectorizedi() { + values = _mm256_setzero_si256(); + } + Vectorizedi(__m256i v) : values(v) {} + operator __m256i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX2 + +#ifdef CPU_CAPABILITY_AVX2 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 4; + } + using Vectorizedi::Vectorizedi; + Vectorized() { + values = _mm256_setzero_si256(); + } + Vectorized(int64_t v) { + values = _mm256_set1_epi64x(v); + } + Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) { + values = _mm256_setr_epi64x(val1, val2, val3, val4); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + __at_align__ int64_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi64(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi64(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi64(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi64(b.values, 3); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int64_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ int64_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to one using "={1}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 1; + } + std::memcpy(tmp_values, ptr, count * sizeof(int64_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int64_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto zero = _mm256_set1_epi64x(0); + auto is_larger = _mm256_cmpgt_epi64(zero, values); + auto inverse = _mm256_xor_si256(values, is_larger); + return _mm256_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi64x(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi64(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi64(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi64(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi64(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi64(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi64(other.values, values)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int32_t; + static constexpr int size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { + values = _mm256_set1_epi32(v); + } + Vectorized( + int32_t val1, + int32_t val2, + int32_t val3, + int32_t val4, + int32_t val5, + int32_t val6, + int32_t val7, + int32_t val8) { + values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm256_blend_epi32(a, b, mask); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int32_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + __at_align__ int32_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to one using "={1}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 1; + } + std::memcpy(tmp_values, ptr, count * sizeof(int32_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int32_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm256_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + int32_t reduce_add() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_si256(v, v, 0x1); + v = _mm256_add_epi32(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0x4E); + v = _mm256_add_epi32(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0xB1); + v = _mm256_add_epi32(v, v1); + __m128i lo = _mm256_castsi256_si128(v); + return _mm_cvtsi128_si32(lo); + } + int32_t reduce_max() const { + auto v = values; + // 128-bit shuffle + auto v1 = _mm256_permute2f128_si256(v, v, 0x1); + v = _mm256_max_epi32(v, v1); + // 64-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0x4E); + v = _mm256_max_epi32(v, v1); + // 32-bit shuffle + v1 = _mm256_shuffle_epi32(v, 0xB1); + v = _mm256_max_epi32(v, v1); + __m128i lo = _mm256_castsi256_si128(v); + return _mm_cvtsi128_si32(lo); + } + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi32(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi32(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi32(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi32(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi32(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi32(other.values, values)); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_vec = + _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm256_cvtepi32_ps(input_vec); + _mm256_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, double* dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_128_vec = + _mm_loadu_si128(reinterpret_cast(src + i)); + auto output_vec = _mm256_cvtepi32_pd(input_128_vec); + _mm256_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int16_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { + values = _mm256_set1_epi16(v); + } + Vectorized( + int16_t val1, + int16_t val2, + int16_t val3, + int16_t val4, + int16_t val5, + int16_t val6, + int16_t val7, + int16_t val8, + int16_t val9, + int16_t val10, + int16_t val11, + int16_t val12, + int16_t val13, + int16_t val14, + int16_t val15, + int16_t val16) { + values = _mm256_setr_epi16( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi16(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi16(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi16(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi16(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi16(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi16(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi16(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi16(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi16(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi16(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi16(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi16(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi16(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi16(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi16(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi16(b.values, 15); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + __at_align__ int16_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to one using "={1}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 1; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm256_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi16(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi16(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi16(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi16(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return _mm256_cmpgt_epi16(values, other.values); + } + Vectorized operator>=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi16(other.values, values)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + + protected: + static const Vectorized ones; + + public: + using value_type = T; + static constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized8() {} + Vectorized8(T v) { + values = _mm256_set1_epi8(v); + } + Vectorized8( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32) { + values = _mm256_setr_epi8( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16, + val17, + val18, + val19, + val20, + val21, + val22, + val23, + val24, + val25, + val26, + val27, + val28, + val29, + val30, + val31, + val32); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + __at_align__ T tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = _mm256_extract_epi8(b.values, 0); + if (mask & 0x02) + tmp_values[1] = _mm256_extract_epi8(b.values, 1); + if (mask & 0x04) + tmp_values[2] = _mm256_extract_epi8(b.values, 2); + if (mask & 0x08) + tmp_values[3] = _mm256_extract_epi8(b.values, 3); + if (mask & 0x10) + tmp_values[4] = _mm256_extract_epi8(b.values, 4); + if (mask & 0x20) + tmp_values[5] = _mm256_extract_epi8(b.values, 5); + if (mask & 0x40) + tmp_values[6] = _mm256_extract_epi8(b.values, 6); + if (mask & 0x80) + tmp_values[7] = _mm256_extract_epi8(b.values, 7); + if (mask & 0x100) + tmp_values[8] = _mm256_extract_epi8(b.values, 8); + if (mask & 0x200) + tmp_values[9] = _mm256_extract_epi8(b.values, 9); + if (mask & 0x400) + tmp_values[10] = _mm256_extract_epi8(b.values, 10); + if (mask & 0x800) + tmp_values[11] = _mm256_extract_epi8(b.values, 11); + if (mask & 0x1000) + tmp_values[12] = _mm256_extract_epi8(b.values, 12); + if (mask & 0x2000) + tmp_values[13] = _mm256_extract_epi8(b.values, 13); + if (mask & 0x4000) + tmp_values[14] = _mm256_extract_epi8(b.values, 14); + if (mask & 0x8000) + tmp_values[15] = _mm256_extract_epi8(b.values, 15); + if (mask & 0x010000) + tmp_values[16] = _mm256_extract_epi8(b.values, 16); + if (mask & 0x020000) + tmp_values[17] = _mm256_extract_epi8(b.values, 17); + if (mask & 0x040000) + tmp_values[18] = _mm256_extract_epi8(b.values, 18); + if (mask & 0x080000) + tmp_values[19] = _mm256_extract_epi8(b.values, 19); + if (mask & 0x100000) + tmp_values[20] = _mm256_extract_epi8(b.values, 20); + if (mask & 0x200000) + tmp_values[21] = _mm256_extract_epi8(b.values, 21); + if (mask & 0x400000) + tmp_values[22] = _mm256_extract_epi8(b.values, 22); + if (mask & 0x800000) + tmp_values[23] = _mm256_extract_epi8(b.values, 23); + if (mask & 0x1000000) + tmp_values[24] = _mm256_extract_epi8(b.values, 24); + if (mask & 0x2000000) + tmp_values[25] = _mm256_extract_epi8(b.values, 25); + if (mask & 0x4000000) + tmp_values[26] = _mm256_extract_epi8(b.values, 26); + if (mask & 0x8000000) + tmp_values[27] = _mm256_extract_epi8(b.values, 27); + if (mask & 0x10000000) + tmp_values[28] = _mm256_extract_epi8(b.values, 28); + if (mask & 0x20000000) + tmp_values[29] = _mm256_extract_epi8(b.values, 29); + if (mask & 0x40000000) + tmp_values[30] = _mm256_extract_epi8(b.values, 30); + if (mask & 0x80000000) + tmp_values[31] = _mm256_extract_epi8(b.values, 31); + return loadu(tmp_values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return _mm256_blendv_epi8(a.values, b.values, mask.values); + } + template + static Vectorized arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set(Vectorized a, Vectorized b, T count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm256_loadu_si256(reinterpret_cast(ptr)); + } + static Vectorized loadu_one_fourth(const void* ptr) { + // Fast path if only load element number of 8. + // Note: We didn't merge it as fast path of loadu(const void* ptr, T count), + // Because loadu(const void* ptr, T count) requires zero initialization for + // upper 128 bits. However, by using _mm256_castsi128_si256, the upper 128 + // bits of the result are undefined. + // TODO We can use _mm256_zextsi128_si256 in the future, + // since gcc 9.3 doesn't support it now. + __m128i input_128 = _mm_loadl_epi64(reinterpret_cast(ptr)); + return _mm256_castsi128_si256(input_128); + } + static Vectorized loadu(const void* ptr, T count) { + __at_align__ T tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to one using "={1}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 1; + } + std::memcpy(tmp_values, ptr, count * sizeof(T)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html + _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); + } else if (count > 0) { + if (count == 8) { + // Fast path if only store element number of 8 + _mm_storel_epi64( + reinterpret_cast<__m128i*>(ptr), _mm256_castsi256_si128(values)); + } else { + __at_align__ T tmp_values[size()]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(T)); + } + } + } + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm256_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + Vectorized neg() const; + + Vectorized abs() const { + return _mm256_abs_epi8(values); + } + + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi8(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi8(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + return _mm256_cmpgt_epi8(other.values, values); + } + Vectorized operator<=(const Vectorized& other) const { + return invert(_mm256_cmpgt_epi8(values, other.values)); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi8(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi8(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return invert(_mm256_cmpeq_epi8(max, values)); + } + Vectorized operator<=(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return _mm256_cmpeq_epi8(max, other.values); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi64(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi16(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi8(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sub_epi8(a, b); +} + +// Negation. Defined here so we can utilize operator- +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +// Emulate operations with no native 64-bit support in avx, +// by extracting each element, performing the operation pointwise, +// then combining the results into a vector. +template +Vectorized inline emulate( + const Vectorized& a, + const Vectorized& b, + const op_t& op) { + int64_t a0 = _mm256_extract_epi64(a, 0); + int64_t a1 = _mm256_extract_epi64(a, 1); + int64_t a2 = _mm256_extract_epi64(a, 2); + int64_t a3 = _mm256_extract_epi64(a, 3); + + int64_t b0 = _mm256_extract_epi64(b, 0); + int64_t b1 = _mm256_extract_epi64(b, 1); + int64_t b2 = _mm256_extract_epi64(b, 2); + int64_t b3 = _mm256_extract_epi64(b, 3); + + int64_t c0 = op(a0, b0); + int64_t c1 = op(a1, b1); + int64_t c2 = op(a2, b2); + int64_t c3 = op(a3, b3); + + return _mm256_set_epi64x(c3, c2, c1, c0); +} + +template +Vectorized inline emulate( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c, + const op_t& op) { + int64_t a0 = _mm256_extract_epi64(a, 0); + int64_t a1 = _mm256_extract_epi64(a, 1); + int64_t a2 = _mm256_extract_epi64(a, 2); + int64_t a3 = _mm256_extract_epi64(a, 3); + + int64_t b0 = _mm256_extract_epi64(b, 0); + int64_t b1 = _mm256_extract_epi64(b, 1); + int64_t b2 = _mm256_extract_epi64(b, 2); + int64_t b3 = _mm256_extract_epi64(b, 3); + + int64_t c0 = _mm256_extract_epi64(c, 0); + int64_t c1 = _mm256_extract_epi64(c, 1); + int64_t c2 = _mm256_extract_epi64(c, 2); + int64_t c3 = _mm256_extract_epi64(c, 3); + + int64_t d0 = op(a0, b0, c0); + int64_t d1 = op(a1, b1, c1); + int64_t d2 = op(a2, b2, c2); + int64_t d3 = op(a3, b3, c3); + + return _mm256_set_epi64x(d3, d2, d1, d0); +} + +// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated +// This could be implemented more efficiently using epi32 instructions +// This is also technically avx compatible, but then we'll need AVX +// code for add as well. +// Note: intentionally ignores undefined behavior like (-lowest * -1). +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return emulate( + a, b, [](int64_t a_point, int64_t b_point) __ubsan_ignore_undefined__ { + return a_point * b_point; + }); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_256( + const Vectorized& a, + const Vectorized& b, + Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying int8_t +#ifndef CPU_CAPABILITY_AVX2 + return int_elementwise_binary_256(a, b, std::multiplies()); +#else + __m256i mask00FF = _mm256_set1_epi16(0x00FF); + __m256i a_lo = _mm256_srai_epi16(_mm256_slli_epi16(a, 8), 8); + __m256i b_lo = _mm256_srai_epi16(_mm256_slli_epi16(b, 8), 8); + __m256i a_hi = _mm256_srai_epi16(a, 8); + __m256i b_hi = _mm256_srai_epi16(b, 8); + __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF); + __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8); + __m256i res = _mm256_or_si256(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t +#ifndef CPU_CAPABILITY_AVX2 + return int_elementwise_binary_256(a, b, std::multiplies()); +#else + __m256i mask00FF = _mm256_set1_epi16(0x00FF); + __m256i a_lo = _mm256_and_si256(a, mask00FF); + __m256i b_lo = _mm256_and_si256(b, mask00FF); + __m256i a_hi = _mm256_srli_epi16(a, 8); + __m256i b_hi = _mm256_srli_epi16(b, 8); + __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF); + __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8); + __m256i res = _mm256_or_si256(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, b, [](int64_t a_point, int64_t b_point) { + return std::min(a_point, b_point); + }); +#else + __m256i cmp = _mm256_cmpgt_epi64(a, b); + return _mm256_blendv_epi8(a, b, cmp); +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi32(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi16(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epi8(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_min_epu8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, b, [](int64_t a_point, int64_t b_point) { + return std::max(a_point, b_point); + }); +#else + __m256i cmp = _mm256_cmpgt_epi64(a, b); + return _mm256_blendv_epi8(b, a, cmp); +#endif +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi32(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi16(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epi8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm256_max_epu8(a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate( + a, + min_val, + max_val, + [](int64_t a_point, int64_t min_point, int64_t max_point) { + return std::min(max_point, std::max(a_point, min_point)); + }); +#else + return minimum(maximum(a, min_val), max_val); +#endif +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, max_val, [](int64_t a_point, int64_t max_point) { + return std::min(max_point, a_point); + }); +#else + return minimum(max_val, a); +#endif +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epi8(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm256_min_epu8(max_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { +#ifndef CPU_CAPABILITY_AVX2 + return emulate(a, min_val, [](int64_t a_point, int64_t min_point) { + return std::max(min_point, a_point); + }); +#else + return maximum(min_val, a); +#endif +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epi8(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm256_max_epu8(min_val, a); +} + +template +std::enable_if_t< + !(std::is_same_v || std::is_same_v), + Vectorized< + int32_t>> inline convert_to_int32(const T* ptr, int count = Vectorized::size()) { + return Vectorized::loadu(ptr, count); +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const int8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm256_cvtepi8_epi32( + _mm_loadl_epi64(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm256_cvtepi8_epi32(_mm256_castsi256_si128(a)); + } +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const uint8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm256_cvtepu8_epi32(_mm256_castsi256_si128(a)); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} + +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm256_and_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm256_or_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm256_xor_si256(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm256_xor_si256(a, _mm256_set1_epi32(-1)); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template +Vectorized inline shift_256_16( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int16_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 16-bit elements, and considering pairs of neighboring + // elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m256i ctl_0_1 = _mm256_set_epi8( + 29, + 28, + 0x80, + 0x80, + 25, + 24, + 0x80, + 0x80, + 21, + 20, + 0x80, + 0x80, + 17, + 16, + 0x80, + 0x80, + 13, + 12, + 0x80, + 0x80, + 9, + 8, + 0x80, + 0x80, + 5, + 4, + 0x80, + 0x80, + 1, + 0, + 0x80, + 0x80); + __m256i ctl_1_0 = _mm256_set_epi8( + 0x80, + 0x80, + 31, + 30, + 0x80, + 0x80, + 27, + 26, + 0x80, + 0x80, + 23, + 22, + 0x80, + 0x80, + 19, + 18, + 0x80, + 0x80, + 15, + 14, + 0x80, + 0x80, + 11, + 10, + 0x80, + 0x80, + 7, + 6, + 0x80, + 0x80, + 3, + 2); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 16-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFFFF); + __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000); + + // Take each 16-bit element with idx%2==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 16 + // bits will be proper result of shifting original 16-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else + c0 = _mm256_srav_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_1_0); + + // Perform shifting the same way for input array elements with + // idx%2==1. + __m256i a1 = _mm256_and_si256(a, keep_1); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else + c1 = _mm256_srav_epi32(a1, b1); + c1 = _mm256_and_si256(c1, keep_1); + + // Merge partial results into the final result. + __m256i c = _mm256_or_si256(c0, c1); + + return c; +} + +template < + bool left_shift, + typename T, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + int> = 0> +Vectorized inline shift_256_8( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 8-bit elements, and considering quadruples of + // neighboring elements. Specifically, a mask named "ctl_M_N" (M,N + // in [0,1,2,3], and M!=N) is set so that shuffle will move element + // with index M from input quadruple into element with index N in + // output quadruple, and other elements in output quadruple will be + // set to all 0s. + __m256i ctl_0_3 = _mm256_set_epi8( + 28, + 0x80, + 0x80, + 0x80, + 24, + 0x80, + 0x80, + 0x80, + 20, + 0x80, + 0x80, + 0x80, + 16, + 0x80, + 0x80, + 0x80, + 12, + 0x80, + 0x80, + 0x80, + 8, + 0x80, + 0x80, + 0x80, + 4, + 0x80, + 0x80, + 0x80, + 0, + 0x80, + 0x80, + 0x80); + __m256i ctl_1_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 29, + 0x80, + 0x80, + 0x80, + 25, + 0x80, + 0x80, + 0x80, + 21, + 0x80, + 0x80, + 0x80, + 17, + 0x80, + 0x80, + 0x80, + 13, + 0x80, + 0x80, + 0x80, + 9, + 0x80, + 0x80, + 0x80, + 5, + 0x80, + 0x80, + 0x80, + 1); + __m256i ctl_1_3 = _mm256_set_epi8( + 29, + 0x80, + 0x80, + 0x80, + 25, + 0x80, + 0x80, + 0x80, + 21, + 0x80, + 0x80, + 0x80, + 17, + 0x80, + 0x80, + 0x80, + 13, + 0x80, + 0x80, + 0x80, + 9, + 0x80, + 0x80, + 0x80, + 5, + 0x80, + 0x80, + 0x80, + 1, + 0x80, + 0x80, + 0x80); + __m256i ctl_2_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 30, + 0x80, + 0x80, + 0x80, + 26, + 0x80, + 0x80, + 0x80, + 22, + 0x80, + 0x80, + 0x80, + 18, + 0x80, + 0x80, + 0x80, + 14, + 0x80, + 0x80, + 0x80, + 10, + 0x80, + 0x80, + 0x80, + 6, + 0x80, + 0x80, + 0x80, + 2); + __m256i ctl_2_3 = _mm256_set_epi8( + 30, + 0x80, + 0x80, + 0x80, + 26, + 0x80, + 0x80, + 0x80, + 22, + 0x80, + 0x80, + 0x80, + 18, + 0x80, + 0x80, + 0x80, + 14, + 0x80, + 0x80, + 0x80, + 10, + 0x80, + 0x80, + 0x80, + 6, + 0x80, + 0x80, + 0x80, + 2, + 0x80, + 0x80, + 0x80); + __m256i ctl_3_0 = _mm256_set_epi8( + 0x80, + 0x80, + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3); + __m256i ctl_3_1 = _mm256_set_epi8( + 0x80, + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3, + 0x80); + __m256i ctl_3_2 = _mm256_set_epi8( + 0x80, + 31, + 0x80, + 0x80, + 0x80, + 27, + 0x80, + 0x80, + 0x80, + 23, + 0x80, + 0x80, + 0x80, + 19, + 0x80, + 0x80, + 0x80, + 15, + 0x80, + 0x80, + 0x80, + 11, + 0x80, + 0x80, + 0x80, + 7, + 0x80, + 0x80, + 0x80, + 3, + 0x80, + 0x80); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 8-bit elements, and considering them in quadruples of neighboring + // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that + // bitwise and will copy element with index M from input quadruple + // into element with the same index in output quadruple, while the + // other elements in output quadruple will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFF); + __m256i keep_3 = _mm256_set1_epi32(0xFF000000); + + // Take each 8-bit element with idx%4==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%4!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else if constexpr (std::is_same_v) + c0 = _mm256_srav_epi32(a0, b0); + else + c0 = _mm256_srlv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_3_0); + + // Perform shifting the same way for input array elements with + // idx%4==1. + __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else if constexpr (std::is_same_v) + c1 = _mm256_srav_epi32(a1, b1); + else + c1 = _mm256_srlv_epi32(a1, b1); + c1 = _mm256_shuffle_epi8(c1, ctl_3_1); + + // Perform shifting the same way for input array elements with + // idx%4==2. + __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3); + __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0); + __m256i c2; + if (left_shift) + c2 = _mm256_sllv_epi32(a2, b2); + else if constexpr (std::is_same_v) + c2 = _mm256_srav_epi32(a2, b2); + else + c2 = _mm256_srlv_epi32(a2, b2); + c2 = _mm256_shuffle_epi8(c2, ctl_3_2); + + // Perform shifting the same way for input array elements with + // idx%4==3. + __m256i a3 = _mm256_and_si256(a, keep_3); + __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0); + __m256i c3; + if (left_shift) + c3 = _mm256_sllv_epi32(a3, b3); + else if constexpr (std::is_same_v) + c3 = _mm256_srav_epi32(a3, b3); + else + c3 = _mm256_srlv_epi32(a3, b3); + c3 = _mm256_and_si256(c3, keep_3); + + // Merge partial results into the final result. + __m256i c01 = _mm256_or_si256(c0, c1); + __m256i c23 = _mm256_or_si256(c2, c3); + __m256i c = _mm256_or_si256(c01, c23); + + return c; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm256_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for right arithmetic shifting int64_t, so emulating + // it instead. + + // Clamp the shift values such that shift values < 0 and > 64 are changed to + // 64 which results in -1 for negative input and 0 for non-negative input. + __m256i zero = _mm256_set1_epi64x(0); + __m256i max_shift = _mm256_set1_epi64x(64); + __m256i mask = _mm256_or_si256( + _mm256_cmpgt_epi64(zero, b), _mm256_cmpgt_epi64(b, max_shift)); + __m256i shift = _mm256_blendv_epi8(b, max_shift, mask); + // Shift the number logically to the right, thus filling the most + // significant bits with 0s. Then, replace these bits with the sign + // bit. + __m256i sign_bits = _mm256_cmpgt_epi64(zero, a); + __m256i sign_shift = _mm256_sub_epi64(max_shift, shift); + __m256i sign_ext = _mm256_sllv_epi64(sign_bits, sign_shift); + __m256i c = _mm256_srlv_epi64(a, shift); + c = _mm256_or_si256(c, sign_ext); + + return c; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm256_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_256_8(a, b); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..595e0c4946a461bb6cc446d202f2156ef4bfbdc9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_mask.h @@ -0,0 +1,303 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + 2, + mask_t, + 1, + typename std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + auto int64_mask = vec_mask.template cast(); + auto result = at::vec::VectorizedN(); + if constexpr (std::is_same_v) { + result[0] = _mm256_maskload_pd(ptr, int64_mask[0]); + result[1] = _mm256_maskload_pd( + ptr + at::vec::Vectorized::size(), int64_mask[1]); + } else { + result[0] = _mm256_maskload_epi64( + reinterpret_cast(ptr), int64_mask[0]); + result[1] = _mm256_maskload_epi64( + reinterpret_cast( + ptr + at::vec::Vectorized::size()), + int64_mask[1]); + } + return result; + } +}; + +// TODO: add specialization of VecMaskLoad for bfloat16/half and int8/uint8 + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castps_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castpd_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +inline bool VecMask::all_zero() const { + return _mm256_testz_si256(mask_[0], mask_[0]); +} + +template <> +inline bool VecMask::is_masked(int i) const { + return _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])) & (1 << i); +} + +template <> +inline bool VecMask::all_masked() const { + int mask = _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])); + return mask == 0xff; +} + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = all_zero && (_mm256_testz_si256(vec_mask[i], vec_mask[i]) > 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 4) { + return _mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[j])) & + (1 << (i - j * 4)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && + (_mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[i])) == 0x0f); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + +#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ + T, N, return_type, method, args_def, args) \ + template <> \ + inline return_type VecMask::method args_def const { \ + return cast().method args; \ + } + +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) + +#undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..7e77d78528b5d6a069347064e8dc21cbf6151682 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h @@ -0,0 +1,1429 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX2) + +#ifdef _MSC_VER +__declspec(align(64)) struct Vectorizedqi { + protected: + __m256i vals; +#else +struct Vectorizedqi { + protected: + __m256i vals __attribute__((aligned(64))); +#endif + + public: + Vectorizedqi() { + vals = _mm256_setzero_si256(); + } + Vectorizedqi(__m256i v) : vals(v) {} + operator __m256i() const { + return vals; + } +}; + +template +__m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + T min_val, + T max_val); + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i /*first*/, + __m256i /*second*/, + int32_t /*min_val*/, + int32_t /*max_val*/) { + // This function is for linkage only, will not be used + TORCH_CHECK(false, "pack_saturate_and_clamp is not supported"); +} + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + int8_t min_val, + int8_t max_val) { + __m256i packed_and_sat = _mm256_packs_epi16(first, second); + return _mm256_max_epi8( + _mm256_set1_epi8(min_val), + _mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val))); +} + +template <> +inline __m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + uint8_t min_val, + uint8_t max_val) { + __m256i packed_and_sat = _mm256_packus_epi16(first, second); + return _mm256_max_epu8( + _mm256_set1_epi8(min_val), + _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val))); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(at::vec::Vectorized src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 8*8 bits + __m128i input_128 = _mm256_castsi256_si128(src); + // Convert from 8*uint8/int8 to 8*int32 + __m256i input_256_int32; + if constexpr (std::is_same_v) + input_256_int32 = _mm256_cvtepu8_epi32(input_128); + else + input_256_int32 = _mm256_cvtepi8_epi32(input_128); + // Convert from 8*int32 to 8*float + return _mm256_cvtepi32_ps(input_256_int32); +} + +template +at::vec::Vectorized inline convert_float_to_int8( + at::vec::Vectorized src); + +template <> +at::vec::Vectorized inline convert_float_to_int8( + at::vec::Vectorized src) { + // Convert from float32 to int32 with truncation + __m256i x_values_int32 = _mm256_cvttps_epi32(src); + + // Convert from int32 to int16 using signed saturation + __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32); + + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to int8 using unsigned saturation + __m256i xyzw_clamped_v = pack_saturate_and_clamp( + xy_packed_v, xy_packed_v, min_val, max_val); + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); +} + +template <> +at::vec::Vectorized inline convert_float_to_int8( + at::vec::Vectorized src) { + // The type of *_val should be int32_t to ensure correct clamping behavior. + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m256 float32_min_val = _mm256_set1_ps(float(min_val)); + __m256 float32_max_val = _mm256_set1_ps(float(max_val)); + __m256 float32_src = _mm256_max_ps(src, float32_min_val); + float32_src = _mm256_min_ps(float32_src, float32_max_val); + __m256i truncated_src = _mm256_cvttps_epi32(float32_src); + + __m128i r1 = _mm256_castsi256_si128(truncated_src); + __m128i mask = _mm_setr_epi8( + 0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m128i r1_shuffled = _mm_shuffle_epi8(r1, mask); + __m128i r2 = _mm256_extractf128_si256(truncated_src, 1); + __m128i r2_shuffled = _mm_shuffle_epi8(r2, mask); + __m128i result = _mm_unpacklo_epi32(r1_shuffled, r2_shuffled); + + return _mm256_castsi128_si256(result); +} + +template +__FORCE_INLINE void QuantizeAvx2( + const float* src, + T* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 8; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m256i min_v = _mm256_set1_epi32(min_val); + const __m256i max_v = _mm256_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); + // clang-format off + static const __m256i shuffle_mask_v = _mm256_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + __m256i permute_mask_l8_v = + _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m256 x_vals = _mm256_load_ps(src + i); + __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm256_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. + x_transformed_v = + _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // y + __m256 y_vals = _mm256_load_ps(src + i + VLEN); + __m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm256_min_ps(y_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // z + __m256 z_vals = _mm256_load_ps(src + i + 2 * VLEN); + __m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm256_min_ps(z_transformed_v, _mm256_set1_ps(int32_float_max_val)); + // w + __m256 w_vals = _mm256_load_ps(src + i + 3 * VLEN); + __m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm256_min_ps(w_transformed_v, _mm256_set1_ps(int32_float_max_val)); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v); + + // add zero point + x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point)); + y_rounded_v = _mm256_add_epi32(y_rounded_v, _mm256_set1_epi32(zero_point)); + z_rounded_v = _mm256_add_epi32(z_rounded_v, _mm256_set1_epi32(zero_point)); + w_rounded_v = _mm256_add_epi32(w_rounded_v, _mm256_set1_epi32(zero_point)); + + __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX2 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m256 x_vals = _mm256_load_ps(src + i); + __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val)); + __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point)); + __m256i x_clipped_v = + _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(dst + i), + _mm256_castsi256_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. + transformed = zero_point + std::nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type kSize = Vectorized::size(); + static constexpr size_type size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_zp_premul) const { + __m256 float_vals = _mm256_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m256 float_vals = _mm256_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + Vectorized retval; + auto rhs_data = (__m256)rhs[0]; + at::native::quantize_vec( + scale, + zero_point, + (float*)&rhs_data, + (c10::qint32*)&retval.vals, + size()); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epi32( + _mm256_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm256_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + + __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v); + __m256i rounded = _mm256_cvtps_epi32(scaled); + return _mm256_add_epi32(rounded, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm256_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m256i RequantizeAvx2( + const std::array, 4>& inp, + __m256 multiplier, + __m256i zp) { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier); + __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[1]), multiplier); + __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[2]), multiplier); + __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[3]), multiplier); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m256i x_v = _mm256_add_epi32(x_rounded_v, zp); + __m256i y_v = _mm256_add_epi32(y_rounded_v, zp); + __m256i z_v = _mm256_add_epi32(z_rounded_v, zp); + __m256i w_v = _mm256_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m256i xy_packed_v = _mm256_packs_epi32(x_v, y_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_v, w_v); + + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 + */ + xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + return xyzw_clamped_v; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; + static constexpr int size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int kIntNumVecs = kSize / Vectorized::size(); + static constexpr int int_num_vecs() { + return kIntNumVecs; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; + using value_type = c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy) + C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy") + C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy") +#endif + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + C10_CLANG_DIAGNOSTIC_POP() + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + private: + __m256i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm256_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_neg_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float /*scale*/, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epi8(_mm256_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256i int32_val0 = cvtepi8_epi32(int_val0); + __m256i int32_val1 = cvtepi8_epi32(int_val1); + __m256i int32_val2 = cvtepi8_epi32(int_val2); + __m256i int32_val3 = cvtepi8_epi32(int_val3); + + __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0)); + __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1)); + __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2)); + __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3)); + + __m256i int32_b0 = cvtepi8_epi32(int_b0); + __m256i int32_b1 = cvtepi8_epi32(int_b1); + __m256i int32_b2 = cvtepi8_epi32(int_b2); + __m256i int32_b3 = cvtepi8_epi32(int_b3); + + __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0); + __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1); + __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2); + __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3); + + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + return RequantizeAvx2(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; + static constexpr int size() { + return kSize; + } + + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); + static constexpr int float_num_vecs() { + return kFloatNumVecs; + } + + static constexpr int kIntNumVecs = kSize / Vectorized::size(); + static constexpr int int_num_vecs() { + return kIntNumVecs; + } + + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; + using value_type = c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m256i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm256_set1_epi8(uw); + } + + // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy) + C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy") + C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy") +#endif + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + C10_CLANG_DIAGNOSTIC_POP() + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm256_storeu_si256((__m256i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return _mm256_loadu_si256((const __m256i*)tmp_values); + } + + private: + __m256i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm256_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized /*zero_point*/, + Vectorized scale_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float /*scale*/, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm256_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm256_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm256_min_epu8(_mm256_max_epu8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256i int32_val0 = cvtepu8_epi32(int_val0); + __m256i int32_val1 = cvtepu8_epi32(int_val1); + __m256i int32_val2 = cvtepu8_epi32(int_val2); + __m256i int32_val3 = cvtepu8_epi32(int_val3); + + __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0)); + __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1)); + __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2)); + __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3)); + + __m256i int32_b0 = cvtepu8_epi32(int_b0); + __m256i int32_b1 = cvtepu8_epi32(int_b1); + __m256i int32_b2 = cvtepu8_epi32(int_b2); + __m256i int32_b3 = cvtepu8_epi32(int_b3); + + __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0); + __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1); + __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2); + __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3); + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m256 multiplier_v = _mm256_set1_ps(multiplier); + __m256i zero_point_v = _mm256_set1_epi32(zero_point); + return RequantizeAvx2(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm256_loadu_si256((const __m256i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#elif !defined(CPU_CAPABILITY_SVE256) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with AVX2. This may not be an issue, because +// currently for quantization we assume the user has at least AVX512 +// installed, so these can simply act as a reference implementation. +// +// If in the future we relax this requirement (AVX2+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size_ / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size_ / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (const auto i : c10::irange(size())) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized /*scale_zp_premul*/) const { + float_vec_return_type rv; + for (const auto i : c10::irange(float_num_vecs())) { + float tmp_vals[Vectorized::size()]; + for (const auto j : c10::irange(Vectorized::size())) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], + zero_point[j], + T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return dequantize(scale, zero_point, scale_zp_premul); + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (const auto i : c10::irange(size())) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = + std::nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return Vectorized(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float /*inverse_scale*/) { + std::array qvals; + std::array::size()> float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * Vectorized::size()]); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + float_vals.size()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // if defined(CPU_CAPABILITY_AVX2) + +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +std::pair, Vectorized> inline convert_int8_to_float( + at::vec::Vectorized src) { + auto s8x8 = vget_low_s8(src); + auto s16x8 = vmovl_s8(s8x8); + + auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); + auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); + + return std::make_pair( + Vectorized(vcvtq_f32_s32(s32x4_lo)), + Vectorized(vcvtq_f32_s32(s32x4_hi))); +} + +std::pair, Vectorized> inline convert_int8_to_float( + at::vec::Vectorized src) { + auto u8x8 = vget_low_u8(src); + auto u16x8 = vmovl_u8(u8x8); + auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); + auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); + + return std::make_pair( + Vectorized(vcvtq_f32_u32(u32x4_lo)), + Vectorized(vcvtq_f32_u32(u32x4_hi))); +} + +Vectorized inline convert_int8_half_register_to_float( + at::vec::Vectorized src) { + auto s8x8 = vget_low_s8(src); + auto s16x8 = vmovl_s8(s8x8); + + auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); + + return Vectorized(vcvtq_f32_s32(s32x4_lo)); +} + +Vectorized inline convert_int8_half_register_to_float( + at::vec::Vectorized src) { + auto u8x8 = vget_low_u8(src); + auto u16x8 = vmovl_u8(u8x8); + auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); + + return Vectorized(vcvtq_f32_u32(u32x4_lo)); +} + +#endif +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..a2cba8d412f2b1f8c5ba60d77d9a42c1ed0639b0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_bfloat16( + const Vectorized& a, + const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +} + +inline void load_fp32_from_bf16( + const c10::BFloat16* data, + Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_bf16( + const c10::BFloat16* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_bf16(data, out1); + data += Vectorized::size(); + load_fp32_from_bf16(data, out2); +} + +inline void load_fp32_from_fp16(const c10::Half* data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (const auto k : c10::irange(Vectorized::size())) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +inline void load_fp32_from_fp16( + const c10::Half* data, + Vectorized& out1, + Vectorized& out2) { + load_fp32_from_fp16(data, out1); + data += Vectorized::size(); + load_fp32_from_fp16(data, out2); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..849f75c2854a361c936288792495f3b6ae0af801 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h @@ -0,0 +1,255 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +// Note: header order is important here +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace at { +namespace vec { + +inline namespace CPU_CAPABILITY { + +DEFINE_CLAMP_FUNCS(c10::quint8) +DEFINE_CLAMP_FUNCS(c10::qint8) +DEFINE_CLAMP_FUNCS(c10::qint32) +DEFINE_CLAMP_FUNCS(int16_t) +DEFINE_CLAMP_FUNCS(int32_t) +DEFINE_CLAMP_FUNCS(int64_t) +DEFINE_CLAMP_FUNCS(float) +DEFINE_CLAMP_FUNCS(double) + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + vec_madd(a.vec0(), b.vec0(), c.vec0()), + vec_madd(a.vec1(), b.vec1(), c.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())}; +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + vint32 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint32 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat32 c0 = vec_float(input_vec0); + vfloat32 c1 = vec_float(input_vec1); + vec_vsx_st(c0, offset0, dst_a); + vec_vsx_st(c1, offset16, dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + vint64 input_vec0 = + vec_vsx_ld(offset0, reinterpret_cast(src_a)); + vint64 input_vec1 = + vec_vsx_ld(offset16, reinterpret_cast(src_a)); + vfloat64 c0 = vec_double(input_vec0); + vfloat64 c1 = vec_double(input_vec1); + vec_vsx_st(c0, offset0, reinterpret_cast(dst_a)); + vec_vsx_st(c1, offset16, reinterpret_cast(dst_a)); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} +// Generic implementation to fix compiler error +// TO-DO : Add optimized version for ppc64 +inline std::tuple, Vectorized> convert_half_float( + const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ Half arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_half( + const Vectorized& a, + const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ Half arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +}; + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + + vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0); + vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3); + vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0); + vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3); + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11}); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0); + vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0); + + vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3); + vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3); + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair( + Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23}); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + + vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0()); + vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233}); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); + vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2}); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..6cc03ca753ae4817b50565c03c732ba3b763a973 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -0,0 +1,684 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +using ComplexDbl = c10::complex; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexDbl; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + using size_type = int; + static constexpr size_type size() { + return 2; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(ComplexDbl val) { + double real_value = val.real(); + double imag_value = val.imag(); + _vec0 = vfloat64{real_value, imag_value}; + _vec1 = vfloat64{real_value, imag_value}; + } + Vectorized(ComplexDbl val1, ComplexDbl val2) { + _vec0 = vfloat64{val1.real(), val1.imag()}; + _vec1 = vfloat64{val2.real(), val2.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return a; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return b; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std:: + enable_if_t> + C10_ALWAYS_INLINE blend( + const Vectorized& a, + const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static Vectorized C10_ALWAYS_INLINE + el_blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = Vectorized( + vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0)); + return { + vec_sel(a._vec0, b._vec0, mask_complex._vecb0), + vec_sel(a._vec1, b._vec1, mask_complex._vecb1)}; + } + + static Vectorized C10_ALWAYS_INLINE elwise_blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + ComplexDbl base = 0., + step_t step = static_cast(1)) { + return Vectorized(base, base + step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + } + return b; + } + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexDbl& operator[](int idx) const = delete; + ComplexDbl& operator[](int idx) = delete; + + Vectorized map(ComplexDbl (*const f)(ComplexDbl)) const { + __at_align__ ComplexDbl tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized map(ComplexDbl (*const f)(const ComplexDbl&)) const { + __at_align__ ComplexDbl tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized el_swapped() const { + vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2); + vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2); + return {v0, v1}; + } + + Vectorized el_madd( + const Vectorized& multiplier, + const Vectorized& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + Vectorized el_mergeo() const { + vfloat64 v0 = vec_splat(_vec0, 1); + vfloat64 v1 = vec_splat(_vec1, 1); + return {v0, v1}; + } + + Vectorized el_mergee() const { + vfloat64 v0 = vec_splat(_vec0, 0); + vfloat64 v1 = vec_splat(_vec1, 0); + return {v0, v1}; + } + + static Vectorized el_mergee( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergeh(first._vec0, second._vec0), + vec_mergeh(first._vec1, second._vec1)}; + } + + static Vectorized el_mergeo( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergel(first._vec0, second._vec0), + vec_mergel(first._vec1, second._vec1)}; + } + + Vectorized abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a; + } + + Vectorized abs_() const { + auto vi = el_mergeo(); + auto vr = el_mergee(); + return { + Sleef_hypotd2_u05vsx(vr._vec0, vi._vec0), + Sleef_hypotd2_u05vsx(vr._vec1, vi._vec1)}; + } + + Vectorized abs() const { + return abs_() & vd_real_mask; + } + + Vectorized angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_pd(values, 0x05); // b a + // return Sleef_atan2d4_u10(values, b_a); // 90-angle angle + Vectorized ret; + ret._vec0[0] = std::atan2(_vec0[1], _vec0[0]); + ret._vec1[0] = std::atan2(_vec1[1], _vec1[0]); + return ret; + } + + Vectorized angle() const { + return angle_() & vd_real_mask; + } + + Vectorized real_() const { + return *this & vd_real_mask; + } + Vectorized real() const { + return *this & vd_real_mask; + } + Vectorized imag_() const { + return *this & vd_imag_mask; + } + Vectorized imag() const { + return imag_().el_swapped(); + } + + Vectorized conj_() const { + return *this ^ vd_isign_mask; + } + Vectorized conj() const { + return *this ^ vd_isign_mask; + } + + Vectorized log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(vd_log2e_inv); + } + Vectorized log10() const { + auto ret = log(); + return ret.elwise_mult(vd_log10e_inv); + } + + Vectorized log1p() const { + return map(std::log1p); + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub(val_2, val_2_swapped); + re = Vectorized(vd_one) - re; + auto root = el_blend<0x0A>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(vd_pi_2) - asin(); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized(vd_imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * vd_imag_half; // i/2*ln() + } + Vectorized atanh() const { + return map(std::atanh); + } + + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized neg() const { + auto z = Vectorized(vd_zero); + return z - *this; + } + Vectorized round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vectorized trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + Vectorized sqrt() const { + return map(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ vd_isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + static Vectorized horizontal_add( + Vectorized& first, + Vectorized& second) { + // Operates on individual floats, see _mm_hadd_ps + // {f0+f1, s0+s1, f2+f3, s2+s3, ...} + // i.e. it sums the re and im of each value and interleaves first and + // second: {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...} + return el_mergee(first, second) + el_mergeo(first, second); + } + + static Vectorized horizontal_sub( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vectorized inline operator*( + const Vectorized& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ vd_rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.elwise_mult(vi) + ret; +#else + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ vd_isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub(ac_bd, ad_bc); +#endif + return ret; + } + + Vectorized inline operator/( + const Vectorized& b) const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() + // auto fabs_cd = Vectorized{ + // vec_andc(b._vec0, vd_sign_mask), + // vec_andc(b._vec1, vd_sign_mask)}; // |c| |d| + // auto fabs_dc = fabs_cd.el_swapped(); // |d| |c| + // auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|) + // auto a2 = elwise_div(scale); // a/sc b/sc + // auto b2 = b.elwise_div(scale); // c/sc d/sc + // auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2 + // auto dc2 = b2.el_swapped(); // d/sc c/sc + // dc2 = dc2 ^ vd_rsign_mask; // -d/sc c/sc + // auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2 + // auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2 + // auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 + // (c^2+d^2)/sc^2 ret = ret.elwise_div(denom2); return ret; + + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + this->store(tmp1); + b.store(tmp2); + + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return loadu(out); + } + + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + + Vectorized pow(const Vectorized& exp) const { + __at_align__ ComplexDbl x_tmp[size()]; + __at_align__ ComplexDbl y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + Vectorized operator<(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator<=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator>(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized operator>=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized eq(const Vectorized& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & vd_one; + } + Vectorized ne(const Vectorized& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & vd_one; + } + + DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor) + // elementwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple) + DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vectorized::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vectorized::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + // (a + ib) * (c + id) = (ac - bd) + i(ad + bc) + // Split into real and imaginary parts + auto a_real = a.el_mergee(); // real part of a + auto a_imag = a.el_mergeo(); // imag part of a + auto b_real = b.el_mergee(); // real part of b + auto b_imag = b.el_mergeo(); // imag part of b + + // Compute components + auto ac = a_real.elwise_mult(b_real); // real*real + auto bd = a_imag.elwise_mult(b_imag); // imag*imag + + // Real part: ac - bd + auto real = ac - bd; + + auto ad = a_real.elwise_mult(b_imag); // real*imag + auto bc = a_imag.elwise_mult(b_real); // imag*real + + // Imag = ad + bc + auto imag = ad + bc; + + // Merge real and imaginary parts into vectors + __vector double v0 = vec_mergeh(real.vec0(), imag.vec0()); // [r0, i0] + __vector double v1 = vec_mergeh(real.vec1(), imag.vec1()); // [r1, i1] + + // Create the final result + auto result = Vectorized{v0, v1}; + return result; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() + // im = (bc - ad)/abs_2() + // Take absolute values of real and imaginary parts of b + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return Vectorized::loadu(out); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..ebeab3693c288277f434948d6e9a805e5b188cf0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h @@ -0,0 +1,776 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) + +#pragma once +#include +#include +#include +#include +#include + +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +using ComplexFlt = c10::complex; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = ComplexFlt; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + using size_type = int; + + static constexpr size_type size() { + return 4; + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(ComplexFlt val) { + float real_value = val.real(); + float imag_value = val.imag(); + _vec0 = vfloat32{real_value, imag_value, real_value, imag_value}; + _vec1 = vfloat32{real_value, imag_value, real_value, imag_value}; + } + + Vectorized( + ComplexFlt val1, + ComplexFlt val2, + ComplexFlt val3, + ComplexFlt val4) { + _vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()}; + _vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()}; + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxComplexMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxComplexMask1(mask); + const vbool32 mask_2nd = VsxComplexMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static Vectorized C10_ALWAYS_INLINE + el_blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + auto mask_complex = Vectorized( + vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1)); + return { + vec_sel( + a._vec0, b._vec0, reinterpret_cast(mask_complex._vec0)), + vec_sel( + a._vec1, b._vec1, reinterpret_cast(mask_complex._vec1)), + }; + } + + static Vectorized elwise_blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, reinterpret_cast(mask._vec0)), + vec_sel(a._vec1, b._vec1, reinterpret_cast(mask._vec1)), + }; + } + + template + static Vectorized arange( + ComplexFlt base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + ComplexFlt(2) * step, + base + ComplexFlt(3) * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + vec_vsx_ld(offset0, reinterpret_cast(tmp_values)), + vec_vsx_ld(offset16, reinterpret_cast(tmp_values))}; + } + + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const ComplexFlt& operator[](int idx) const = delete; + ComplexFlt& operator[](int idx) = delete; + + Vectorized map(ComplexFlt (*const f)(ComplexFlt)) const { + __at_align__ ComplexFlt tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + Vectorized map(ComplexFlt (*const f)(const ComplexFlt&)) const { + __at_align__ ComplexFlt tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + + static Vectorized horizontal_add( + Vectorized& first, + Vectorized& second) { + // Operates on individual floats, see _mm_hadd_ps + // {f0+f1, s0+s1, f2+f3, s2+s3, ...} + // i.e. it sums the re and im of each value and interleaves first and + // second: {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...} + return el_mergee(first, second) + el_mergeo(first, second); + } + + static Vectorized horizontal_sub_permD8( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.el_swapped(); // 2perm + auto second_perm = second.el_swapped(); // 2perm + // sum + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return el_mergee(first_ret, second_ret); // 2 mergee's + } + + Vectorized abs_2_() const { + auto a = (*this).elwise_mult(*this); + auto permuted = a.el_swapped(); + a = a + permuted; + return a.el_mergee(); + } + + Vectorized abs_() const { + auto vi = el_mergeo(); + auto vr = el_mergee(); + return { + Sleef_hypotf4_u05vsx(vr._vec0, vi._vec0), + Sleef_hypotf4_u05vsx(vr._vec1, vi._vec1)}; + } + + Vectorized abs() const { + return abs_() & real_mask; + } + + Vectorized real_() const { + return *this & real_mask; + } + Vectorized real() const { + return *this & real_mask; + } + Vectorized imag_() const { + return *this & imag_mask; + } + Vectorized imag() const { + // we can use swap_mask or sldwi + auto ret = imag_(); + return { + vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)}; + } + + Vectorized conj_() const { + return *this ^ isign_mask; + } + Vectorized conj() const { + return *this ^ isign_mask; + } + + Vectorized log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return ret.elwise_mult(log2e_inv); + } + Vectorized log10() const { + auto ret = log(); + return ret.elwise_mult(log10e_inv); + } + + Vectorized log1p() const { + return map(std::log1p); + } + + Vectorized el_swapped() const { + vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask); + vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + Vectorized el_mergee() const { + // as mergee phased in , we can use vec_perm with mask + return {vec_mergee(_vecb0, _vecb0), vec_mergee(_vecb1, _vecb1)}; + } + + Vectorized el_mergeo() const { + // as mergeo phased in , we can use vec_perm with mask + return {vec_mergeo(_vecb0, _vecb0), vec_mergeo(_vecb1, _vecb1)}; + } + + Vectorized el_madd( + const Vectorized& multiplier, + const Vectorized& val) const { + return { + vec_madd(_vec0, multiplier._vec0, val._vec0), + vec_madd(_vec1, multiplier._vec1, val._vec1)}; + } + + static Vectorized el_mergee( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergee(first._vecb0, second._vecb0), + vec_mergee(first._vecb1, second._vecb1)}; + } + + static Vectorized el_mergeo( + const Vectorized& first, + const Vectorized& second) { + return { + vec_mergeo(first._vecb0, second._vecb0), + vec_mergeo(first._vecb1, second._vecb1)}; + } + + Vectorized angle_() const { + // angle = atan2(b/a) + // auto b_a = _mm256_permute_ps(values, 0xB1); // b a + // return Sleef_atan2f8_u10(values, b_a); // 90-angle angle + Vectorized ret; + for (int i = 0; i < 4; i += 2) { + ret._vec0[i] = std::atan2(_vec0[i + 1], _vec0[i]); + ret._vec1[i] = std::atan2(_vec1[i + 1], _vec1[i]); + } + return ret; + } + + Vectorized angle() const { + return angle_() & real_mask; + } + + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + Vectorized ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized neg() const { + auto z = Vectorized(zero); + return z - *this; + } + Vectorized round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized elwise_sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + + Vectorized sqrt() const { + return map(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + auto c_d = *this ^ isign_mask; // c -d + auto abs = abs_2_(); + return c_d.elwise_div(abs); + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized pow(const Vectorized& exp) const { + __at_align__ ComplexFlt x_tmp[size()]; + __at_align__ ComplexFlt y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized(imag_one); + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * imag_half; // i/2*ln() + } + Vectorized atanh() const { + return map(std::atanh); + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(pi_2) - asin(); + } + + Vectorized inline operator*( + const Vectorized& b) const { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + +#if 1 + // this is more vsx friendly than simulating horizontal from x86 + + auto vi = b.el_mergeo(); + auto vr = b.el_mergee(); + vi = vi ^ rsign_mask; + auto ret = elwise_mult(vr); + auto vx_swapped = el_swapped(); + ret = vx_swapped.elwise_mult(vi) + ret; + return ret; + +#else + + auto ac_bd = elwise_mult(b); + auto d_c = b.el_swapped(); + d_c = d_c ^ isign_mask; + auto ad_bc = elwise_mult(d_c); + auto ret = horizontal_sub_permD8(ac_bd, ad_bc); + return ret; +#endif + } + + Vectorized inline operator/( + const Vectorized& b) const { +#if 1 + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + this->store(tmp1); + b.store(tmp2); + + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return loadu(out); +#else + auto fabs_cd = Vectorized{ + vec_andc(b._vec0, sign_mask), vec_andc(b._vec1, sign_mask)}; // |c| |d| + auto fabs_dc = fabs_cd.el_swapped(); // |d| |c| + auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|) + auto a2 = elwise_div(scale); // a/sc b/sc + auto b2 = b.elwise_div(scale); // c/sc d/sc + auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/s + auto dc2 = b2.el_swapped(); // d/sc c/sc + dc2 = dc2 ^ rsign_mask; // -d/sc c/sc + auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2 + auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2 + auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 + ret = ret.elwise_div(denom2); + return ret; +#endif + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + +#if 1 + auto conj = conj_(); + auto b_a = conj.el_swapped(); + auto ab = conj.elwise_mult(b_a); + auto im = ab + ab; + auto val_2 = (*this).elwise_mult(*this); + auto val_2_swapped = val_2.el_swapped(); + auto re = horizontal_sub_permD8(val_2, val_2_swapped); + re = Vectorized(one) - re; + auto root = el_blend<0xAA>(re, im).sqrt(); + auto ln = (b_a + root).log(); + return ln.el_swapped().conj(); +#else + return map(std::asin); +#endif + } + + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + + Vectorized eq(const Vectorized& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & one; + } + Vectorized ne(const Vectorized& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & one; + } + + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + Vectorized operator<(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator<=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator>(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized operator>=(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne) + + DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add) + DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub) + DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and) + DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or) + DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor) + // elementwise helpers + DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul) + DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div) + DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt) + DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge) + DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt) + DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple) + DEFINE_MEMBER_OP(elwise_max, ComplexFlt, vec_max) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); + // auto max = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_lt(abs_b); + auto max = Vectorized::elwise_blendv(a, b, mask); + + return max; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(max, isnan); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); + // auto min = _mm256_blendv_ps(a, b, mask); + auto mask = abs_a.elwise_gt(abs_b); + auto min = Vectorized::elwise_blendv(a, b, mask); + return min; + // Exploit the fact that all-ones is a NaN. + // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); + // return _mm256_or_ps(min, isnan); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + // (a + ib) * (c + id) = (ac - bd) + i(ad + bc) + // Split into real and imaginary parts + auto a_real = a.el_mergee(); // real part of a + auto a_imag = a.el_mergeo(); // imag part of a + auto b_real = b.el_mergee(); // real part of b + auto b_imag = b.el_mergeo(); // imag part of b + + auto b_imag_neg = b_imag ^ rsign_mask; + // Compute components + auto ac = a_real.elwise_mult(b_real); // real * real + auto bd = a_imag.elwise_mult(b_imag_neg); // imag * imag + auto ad = a_real.elwise_mult(b_imag); // real * imag + auto bc = a_imag.elwise_mult(b_real); // imag * real + + // Real = ac - bd (fix the negative bd part) + auto real = ac + bd; // Real part calculation + auto imag = ad + bc; // Imaginary part calculation + + // Step 1: Extract from real and imag + __vector float r0 = real.vec0(); // {r0, r1, r2, r3} + __vector float i0 = imag.vec0(); // {i0, i1, i2, i3} + + __vector float r1 = real.vec1(); // imag[0..3] + __vector float i1 = imag.vec1(); // imag[4..7] + + __vector unsigned char perm_lo = { + 0, + 1, + 2, + 3, // r0 + 16, + 17, + 18, + 19, // + 8, + 9, + 10, + 11, // r1 + 24, + 25, + 26, + 27}; + __vector float v0 = + vec_perm(r0, i0, perm_lo); // Interleave r0 and i0, r1 and i1 + __vector float v1 = vec_perm(r1, i1, perm_lo); + Vectorized result(v0, v1); + return result; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + // Take absolute values of real and imaginary parts of b + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : + c10::irange(Vectorized>:: + size())) { //{Vectorized>::size())) + //{ + out[i] = tmp1[i] / tmp2[i]; + } + return Vectorized::loadu(out); +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..63a9e5e2f1ad1328a85db5e0228b81dfd41ab215 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -0,0 +1,520 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include + +namespace at { +namespace vec { + +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat64 _vec0; + vfloat64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = double; + using vec_internal_type = vfloat64; + using vec_internal_mask_type = vbool64; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(double scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + double scalar1, + double scalar2, + double scalar3, + double scalar4) + : _vec0{vfloat64{scalar1, scalar2}}, _vec1{vfloat64{scalar3, scalar4}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + int zero_mask() const { + auto cmp = (*this == vd_zero); + return (cmp._vecb0[0] & 1) | (cmp._vecb0[1] & 2) | (cmp._vecb1[0] & 4) | + (cmp._vecb1[1] & 8); + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return {(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + return {(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return {a._vec0, (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_2nd = VsxDblMask2(mask); + // generated masks + return {b._vec0, (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool64 mask_1st = VsxDblMask1(mask); + const vbool64 mask_2nd = VsxDblMask2(mask); + return { + (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparison of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + + static Vectorized C10_ALWAYS_INLINE + set(const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + Vectorized map(double (*const f)(double)) const { + Vectorized ret; + for (const auto i : c10::irange(size() / 2)) { + ret._vec0[i] = f(_vec0[i]); + } + for (const auto i : c10::irange(size() / 2)) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vectorized mapbi( + double (*const f)(double, double), + const Vectorized& other) const { + Vectorized ret; + for (const auto i : c10::irange(size() / 2)) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (const auto i : c10::irange(size() / 2)) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE acos() const { + return {Sleef_acosd2_u10(_vec0), Sleef_acosd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE acosh() const { + return {Sleef_acoshd2_u10(_vec0), Sleef_acoshd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asin() const { + return {Sleef_asind2_u10(_vec0), Sleef_asind2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asinh() const { + return {Sleef_asinhd2_u10(_vec0), Sleef_asinhd2_u10(_vec1)}; + } + Vectorized atan() const { + return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)}; + } + Vectorized atanh() const { + return {Sleef_atanhd2_u10(_vec0), Sleef_atanhd2_u10(_vec1)}; + } + Vectorized atan2(const Vectorized& b) const { + return { + Sleef_atan2d2_u10(_vec0, b._vec0), Sleef_atan2d2_u10(_vec1, b._vec1)}; + } + Vectorized copysign(const Vectorized& sign) const { + return { + Sleef_copysignd2(_vec0, sign._vec0), + Sleef_copysignd2(_vec1, sign._vec1)}; + } + Vectorized erf() const { + return {Sleef_erfd2_u10(_vec0), Sleef_erfd2_u10(_vec1)}; + } + Vectorized erfc() const { + return {Sleef_erfcd2_u15(_vec0), Sleef_erfcd2_u15(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp() const { + return {Sleef_expd2_u10(_vec0), Sleef_expd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp2() const { + return {Sleef_exp2d2_u10(_vec0), Sleef_exp2d2_u10(_vec1)}; + } + Vectorized expm1() const { + return {Sleef_expm1d2_u10(_vec0), Sleef_expm1d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp_u20() const { + return exp(); + } + Vectorized C10_ALWAYS_INLINE fexp_u20() const { + return exp(); + } + + Vectorized lgamma() const __ubsan_ignore_undefined__ { + return {Sleef_lgammad2_u10(_vec0), Sleef_lgammad2_u10(_vec1)}; + } + + Vectorized erfinv() const { + return map(calc_erfinv); + } + + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE log() const { + return {Sleef_logd2_u10(_vec0), Sleef_logd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log10() const { + return {Sleef_log10d2_u10(_vec0), Sleef_log10d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log1p() const { + return {Sleef_log1pd2_u10(_vec0), Sleef_log1pd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log2() const { + return {Sleef_log2d2_u10(_vec0), Sleef_log2d2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cos() const { + return {Sleef_cosd2_u10(_vec0), Sleef_cosd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshd2_u10(_vec0), Sleef_coshd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sin() const { + return {Sleef_sind2_u10(_vec0), Sleef_sind2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhd2_u10(_vec0), Sleef_sinhd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tan() const { + return {Sleef_tand2_u10(_vec0), Sleef_tand2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhd2_u10(_vec0), Sleef_tanhd2_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return { + vec_div(vd_one, _vec0), // vec_re(_vec0) is estimated one. + vec_div(vd_one, _vec1)}; + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized C10_ALWAYS_INLINE pow(const Vectorized& b) const { + return {Sleef_powd2_u10(_vec0, b._vec0), Sleef_powd2_u10(_vec1, b._vec1)}; + } + Vectorized C10_ALWAYS_INLINE fmod(const Vectorized& b) const { + return {Sleef_fmodd2(_vec0, b._vec0), Sleef_fmodd2(_vec1, b._vec1)}; + } + + Vectorized hypot(const Vectorized& b) const { + return { + Sleef_hypotd2_u05(_vec0, b._vec0), Sleef_hypotd2_u05(_vec1, b._vec1)}; + } + + Vectorized nextafter(const Vectorized& b) const { + return { + Sleef_nextafterd2(_vec0, b._vec0), Sleef_nextafterd2(_vec1, b._vec1)}; + } + + Vectorized igamma(const Vectorized& x) const { + return mapbi(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapbi(calc_igammac, x); + } + + Vectorized i0() const { + return map(calc_i0); + } + + Vectorized i0e() const { + return map(calc_i0e); + } + + Vectorized digamma() const { + return map(calc_digamma); + } + + Vectorized _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + DEFINE_MEMBER_OP(operator==, double, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, double, vec_cmpne) + DEFINE_MEMBER_OP(operator<, double, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, double, vec_cmple) + DEFINE_MEMBER_OP(operator>, double, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, double, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, double, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, double, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, double, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, double, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, double, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, double, vec_cmpge) + DEFINE_MEMBER_OP(operator+, double, vec_add) + DEFINE_MEMBER_OP(operator-, double, vec_sub) + DEFINE_MEMBER_OP(operator*, double, vec_mul) + DEFINE_MEMBER_OP(operator/, double, vec_div) + DEFINE_MEMBER_OP(maximum, double, vec_max_nan2) + DEFINE_MEMBER_OP(minimum, double, vec_min_nan2) + DEFINE_MEMBER_OP(operator&, double, vec_and) + DEFINE_MEMBER_OP(operator|, double, vec_or) + DEFINE_MEMBER_OP(operator^, double, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, double, vec_madd) +}; +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..f26ea32fe0b1e8d2ab91149b28b002ceadfa1f3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -0,0 +1,553 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] + +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vfloat32 _vec0; + vfloat32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = float; + using vec_internal_type = vfloat32; + using vec_internal_mask_type = vbool32; + using size_type = int; + + static constexpr size_type size() { + return 8; + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) + : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(float scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + float scalar1, + float scalar2, + float scalar3, + float scalar4, + float scalar5, + float scalar6, + float scalar7, + float scalar8) + : _vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_2nd = VsxMask2(mask); + // generated masks + return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + const vbool32 mask_1st = VsxMask1(mask); + const vbool32 mask_2nd = VsxMask2(mask); + return { + (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), + (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparison of vec256 + // assuming this we can use the same mask directly with vec_sel + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + + Vectorized map(float (*const f)(float)) const { + Vectorized ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i]); + } + return ret; + } + + Vectorized mapbi( + float (*const f)(float, float), + const Vectorized& other) const { + Vectorized ret; + for (int i = 0; i < size() / 2; i++) { + ret._vec0[i] = f(_vec0[i], other._vec0[i]); + } + for (int i = 0; i < size() / 2; i++) { + ret._vec1[i] = f(_vec1[i], other._vec1[i]); + } + return ret; + } + + Vectorized _nor() const { + return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._nor(); + } + + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + //__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ); + auto cmp = (*this == zero); + // return _mm256_movemask_ps(cmp); + // possible simulation //mask= lvsl ( 0 ) vbpermq( vec, mask <<5) + vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits); + vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits); + return (result0[1] >> 12 | (result1[1] >> 8)); + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE acos() const { + return {Sleef_acosf4_u10(_vec0), Sleef_acosf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE acosh() const { + return {Sleef_acoshf4_u10(_vec0), Sleef_acoshf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asin() const { + return {Sleef_asinf4_u10(_vec0), Sleef_asinf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE asinh() const { + return {Sleef_asinhf4_u10(_vec0), Sleef_asinhf4_u10(_vec1)}; + } + Vectorized atan() const { + return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)}; + } + Vectorized atanh() const { + return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)}; + } + Vectorized atan2(const Vectorized& b) const { + return { + Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)}; + } + Vectorized copysign(const Vectorized& sign) const { + return { + Sleef_copysignf4(_vec0, sign._vec0), + Sleef_copysignf4(_vec1, sign._vec1)}; + } + Vectorized lgamma() const { + return {Sleef_lgammaf4_u10(_vec0), Sleef_lgammaf4_u10(_vec1)}; + } + Vectorized erf() const { + return {Sleef_erff4_u10(_vec0), Sleef_erff4_u10(_vec1)}; + } + + Vectorized erfc() const { + return {Sleef_erfcf4_u15(_vec0), Sleef_erfcf4_u15(_vec1)}; + } + + Vectorized erfinv() const { + return map(calc_erfinv); + } + + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE exp() const { + return {Sleef_expf4_u10(_vec0), Sleef_expf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp2() const { + return {Sleef_exp2f4_u10(_vec0), Sleef_exp2f4_u10(_vec1)}; + } + Vectorized expm1() const { + return {Sleef_expm1f4_u10(_vec0), Sleef_expm1f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE exp_u20() const { + return exp(); + } + Vectorized C10_ALWAYS_INLINE fexp_u20() const { + return exp(); + } + + Vectorized C10_ALWAYS_INLINE log() const { + return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log10() const { + return {Sleef_log10f4_u10(_vec0), Sleef_log10f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log1p() const { + return {Sleef_log1pf4_u10(_vec0), Sleef_log1pf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE log2() const { + return {Sleef_log2f4_u10(_vec0), Sleef_log2f4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cos() const { + return {Sleef_cosf4_u10(_vec0), Sleef_cosf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE cosh() const { + return {Sleef_coshf4_u10(_vec0), Sleef_coshf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sin() const { + return {Sleef_sinf4_u10(_vec0), Sleef_sinf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE sinh() const { + return {Sleef_sinhf4_u10(_vec0), Sleef_sinhf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tan() const { + return {Sleef_tanf4_u10(_vec0), Sleef_tanf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE tanh() const { + return {Sleef_tanhf4_u10(_vec0), Sleef_tanhf4_u10(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return Vectorized(one) / (*this); + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized C10_ALWAYS_INLINE pow(const Vectorized& exp) const { + return { + Sleef_powf4_u10(_vec0, exp._vec0), Sleef_powf4_u10(_vec1, exp._vec1)}; + } + + Vectorized fmod(const Vectorized& b) const { + return {Sleef_fmodf4(_vec0, b._vec0), Sleef_fmodf4(_vec1, b._vec1)}; + } + + Vectorized hypot(const Vectorized& b) const { + return { + Sleef_hypotf4_u05(_vec0, b._vec0), Sleef_hypotf4_u05(_vec1, b._vec1)}; + } + + Vectorized nextafter(const Vectorized& b) const { + return { + Sleef_nextafterf4(_vec0, b._vec0), Sleef_nextafterf4(_vec1, b._vec1)}; + } + + Vectorized igamma(const Vectorized& x) const { + return mapbi(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapbi(calc_igammac, x); + } + + Vectorized i0() const { + return map(calc_i0); + } + + Vectorized i0e() const { + return map(calc_i0e); + } + + Vectorized digamma() const { + return map(calc_digamma); + } + + DEFINE_MEMBER_OP(operator==, float, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, float, vec_cmpne) + DEFINE_MEMBER_OP(operator<, float, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, float, vec_cmple) + DEFINE_MEMBER_OP(operator>, float, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, float, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge) + DEFINE_MEMBER_OP(operator+, float, vec_add) + DEFINE_MEMBER_OP(operator-, float, vec_sub) + DEFINE_MEMBER_OP(operator*, float, vec_mul) + DEFINE_MEMBER_OP(operator/, float, vec_div) + DEFINE_MEMBER_OP(maximum, float, vec_max_nan2) + DEFINE_MEMBER_OP(minimum, float, vec_min_nan2) + DEFINE_MEMBER_OP(operator&, float, vec_and) + DEFINE_MEMBER_OP(operator|, float, vec_or) + DEFINE_MEMBER_OP(operator^, float, vec_xor) + DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..5150ccf3a2cd6df9c05e1f2b1184912ebd9ad7fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -0,0 +1,422 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint16 _vec0; + vint16 _vec1; + }; + struct { + vbool16 _vecb0; + vbool16 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int16_t; + using vec_internal_type = vint16; + using vec_internal_mask_type = vbool16; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool16 v1, vbool16 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int16_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + + C10_ALWAYS_INLINE Vectorized( + int16_t scalar1, + int16_t scalar2, + int16_t scalar3, + int16_t scalar4, + int16_t scalar5, + int16_t scalar6, + int16_t scalar7, + int16_t scalar8, + int16_t scalar9, + int16_t scalar10, + int16_t scalar11, + int16_t scalar12, + int16_t scalar13, + int16_t scalar14, + int16_t scalar15, + int16_t scalar16) + : _vec0{vint16{ + scalar1, + scalar2, + scalar3, + scalar4, + scalar5, + scalar6, + scalar7, + scalar8}}, + _vec1{vint16{ + scalar9, + scalar10, + scalar11, + scalar12, + scalar13, + scalar14, + scalar15, + scalar16}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t<(mask & 65535) == 65535, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 255), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + + return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) && + ((mask & 255) != 255)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int16_t g0 = (mask & 1) * 0xffff; + constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff; + constexpr int16_t mask2 = (mask & 65535) >> 16; + constexpr int16_t g0_2 = (mask & 1) * 0xffff; + constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff; + constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff; + constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff; + constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff; + constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff; + constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff; + constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff; + + const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7}; + const vint16 mask_2nd = + vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2}; + // generated masks + return { + (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), + (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparison of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not) + DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int16_t, vec_add) + DEFINE_MEMBER_OP(operator-, int16_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int16_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /) + DEFINE_MEMBER_OP(maximum, int16_t, vec_max) + DEFINE_MEMBER_OP(minimum, int16_t, vec_min) + DEFINE_MEMBER_OP(operator&, int16_t, vec_and) + DEFINE_MEMBER_OP(operator|, int16_t, vec_or) + DEFINE_MEMBER_OP(operator^, int16_t, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +DEFINE_SHIFT_FUNCS(int16_t) + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..baa0a95a9bd194a8a4f7cc3a1518a77d12bd8e58 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -0,0 +1,352 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int32_t; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int32_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + int32_t scalar1, + int32_t scalar2, + int32_t scalar3, + int32_t scalar4, + int32_t scalar5, + int32_t scalar6, + int32_t scalar7, + int32_t scalar8) + : _vec0{vint32{scalar1, scalar2, scalar3, scalar4}}, + _vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {} + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t<(mask & 255) == 255, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask > 0 && mask < 15), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + + return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1}; + } + + template + static std::enable_if_t< + (mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) && + ((mask & 15) != 15)), + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint32_t g0 = (mask & 1) * 0xffffffff; + constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + constexpr uint32_t mask2 = (mask & 255) >> 4; + constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff; + constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff; + constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff; + constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff; + + const vbool32 mask_1st = (vbool32){g0, g1, g2, g3}; + const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2}; + // generated masks + return { + (vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), + (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparison of vec256 + // assuming this we can use the same mask directly with vec_sel + // warning intel style mask will not work properly + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + + template + static Vectorized arange( + int32_t base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not) + DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int32_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int32_t, vec_add) + DEFINE_MEMBER_OP(operator-, int32_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int32_t, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /) + DEFINE_MEMBER_OP(maximum, int32_t, vec_max) + DEFINE_MEMBER_OP(minimum, int32_t, vec_min) + DEFINE_MEMBER_OP(operator&, int32_t, vec_and) + DEFINE_MEMBER_OP(operator|, int32_t, vec_or) + DEFINE_MEMBER_OP(operator^, int32_t, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +DEFINE_SHIFT_FUNCS(int32_t) + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..c3012293b3c7b0c10855f86f6c747b50e4ee1a17 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -0,0 +1,306 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + union { + struct { + vint64 _vec0; + vint64 _vec1; + }; + struct { + vbool64 _vecb0; + vbool64 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + using value_type = int64_t; + using vec_internal_type = vint64; + using vec_internal_mask_type = vbool64; + using size_type = int; + using ElementType = signed long long; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + C10_ALWAYS_INLINE Vectorized(vint64 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) + : _vecb0{v1}, _vecb1{v2} {} + C10_ALWAYS_INLINE Vectorized(int64_t scalar) + : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {} + C10_ALWAYS_INLINE Vectorized( + int64_t scalar1, + int64_t scalar2, + int64_t scalar3, + int64_t scalar4) + : _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t> C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t<(mask & 15) == 15, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t<(mask > 0 && mask < 3), Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + const vbool64 mask_1st = (vbool64){g0, g1}; + return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1}; + } + + template + static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return {a._vec0, (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + template + static std::enable_if_t< + (mask > 3) && (mask & 3) != 0 && (mask & 15) != 15, + Vectorized> + C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff; + constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff; + + const vbool64 mask_1st = (vbool64){g0, g1}; + const vbool64 mask_2nd = (vbool64){g0_2, g1_2}; + return { + (vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), + (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)}; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // the mask used here returned by comparison of vec256 + + return { + vec_sel(a._vec0, b._vec0, mask._vecb0), + vec_sel(a._vec1, b._vec1, mask._vecb1)}; + } + template + static Vectorized arange( + int64_t base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step); + } + + static Vectorized C10_ALWAYS_INLINE + set(const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + + return b; + } + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + static_assert(sizeof(double) == sizeof(value_type)); + const double* dptr = reinterpret_cast(ptr); + return {// treat it as double load + (vint64)vec_vsx_ld(offset0, dptr), + (vint64)vec_vsx_ld(offset16, dptr)}; + } + + __at_align__ double tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return { + (vint64)vec_vsx_ld(offset0, tmp_values), + (vint64)vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + double* dptr = reinterpret_cast(ptr); + vec_vsx_st((vfloat64)_vec0, offset0, dptr); + vec_vsx_st((vfloat64)_vec1, offset16, dptr); + } else if (count > 0) { + __at_align__ double tmp_values[size()]; + vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); + vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + + Vectorized angle() const { + return blendv( + Vectorized(0), + Vectorized(c10::pi), + *this < Vectorized(0)); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {vec_neg(_vec0), vec_neg(_vec1)}; + } + + DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not) + DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne) + DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple) + DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge) + DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq) + DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne) + DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt) + DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple) + DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt) + DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge) + DEFINE_MEMBER_OP(operator+, int64_t, vec_add) + DEFINE_MEMBER_OP(operator-, int64_t, vec_sub) + DEFINE_MEMBER_OP(operator*, int64_t, vec_mul) + DEFINE_MEMBER_OP(operator/, int64_t, vec_div) + DEFINE_MEMBER_OP(maximum, int64_t, vec_max) + DEFINE_MEMBER_OP(minimum, int64_t, vec_min) + DEFINE_MEMBER_OP(operator&, int64_t, vec_and) + DEFINE_MEMBER_OP(operator|, int64_t, vec_or) + DEFINE_MEMBER_OP(operator^, int64_t, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +DEFINE_SHIFT_FUNCS(int64_t) + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_mask_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_mask_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..f02be95efa692b75a8ba7349492d58177b66a978 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_mask_vsx.h @@ -0,0 +1,74 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_VSX) + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; + + for (int i = 0; i < N; ++i) { + auto tmp = vec_mask[i]; + result[i] = reinterpret_cast&>(tmp); + } + return VecMask(result); + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; + + for (int i = 0; i < N; ++i) { + auto tmp = vec_mask[i]; + result[i] = reinterpret_cast&>(tmp); + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v)>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + + auto int_mask = vec_mask.template cast(); + + for (int i = 0; i < mask_n; ++i) { + VectorizedN in_int_n; + in_int_n[0] = int_mask[i]; + + auto int64_vecs = convert(in_int_n); + + result[2 * i] = int64_vecs[0]; + result[2 * i + 1] = int64_vecs[1]; + } + return VecMask(result); + } +}; + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..692607d4d5254353f74d43ce88404cb96d9d770b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -0,0 +1,306 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +struct Vectorized { + private: + union { + struct { + vint32 _vec0; + vint32 _vec1; + }; + struct { + vbool32 _vecb0; + vbool32 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + + using size_type = int; + static constexpr size_type size() { + return 8; + } + + static constexpr size_t float_num_vecs() { + return 1; + } + static constexpr int int_num_vecs() { + return 1; + } + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + using vec_internal_type = vint32; + using vec_internal_mask_type = vbool32; + C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) + : _vecb0{v1}, _vecb1{v2} {} + + Vectorized(const c10::qint32& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + static Vectorized C10_ALWAYS_INLINE + loadu(const void* ptr, int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_sub_zero_point_0 = vec_sub(float_vals0, zero_point_vec0); + vfloat32 vec_sub_zero_point_1 = vec_sub(float_vals1, zero_point_vec1); + Vectorized vf0 = { + vec_mul(scale_vec0, vec_sub_zero_point_0), + vec_mul(scale_vec1, vec_sub_zero_point_1)}; + return {vf0}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + vfloat32 float_vals0 = vec_float(_vec0); + vfloat32 float_vals1 = vec_float(_vec1); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return {Vectorized{ + (float_vals0 - zero_point0) * scale_vec0, + (float_vals1 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized retval; + + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)(zero_point)); + Vectorized vf0 = rhs[0]; + + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + Vectorized relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) const { + vint32 max0 = vec_max(_vec0, zero_point._vec0); + vint32 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {*this - b}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + const vint32 vmin = vec_splats(std::numeric_limits::min()); + const vint32 vmax = vec_splats(std::numeric_limits::max()); + vfloat32 vec_mult = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + Vectorized vi = inp[0]; + vfloat32 vecf0 = vec_float(vi.vec0()); + vfloat32 vecf1 = vec_float(vi.vec1()); + + vecf0 = vec_mul(vecf0, vec_mult); + vecf1 = vec_mul(vecf1, vec_mult); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + + vint32 veci0 = vec_add(vec_signed(vecf0), vec_zero_point); + vint32 veci1 = vec_add(vec_signed(vecf1), vec_zero_point); + + veci0 = vec_max(veci0, vmin); + veci1 = vec_max(veci1, vmin); + veci0 = vec_min(veci0, vmax); + veci1 = vec_min(veci1, vmax); + + return {veci0, veci1}; + } + + DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /) + DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..3fb5b62c5c0d898bd0fba05898123b7fa53bed5e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h @@ -0,0 +1,517 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +struct Vectorized { + private: + union { + struct { + vint8 _vec0; + vint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + using size_type = int; + static constexpr size_type size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + using vec_internal_type = vint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vectorized(const c10::qint8& val) + : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {} + + C10_ALWAYS_INLINE Vectorized(const Vectorized& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vectorized(vint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vectorized loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_substract_src_zp0_0 = vec_sub(vecf0_0, zero_point_vec0); + vfloat32 vec_substract_src_zp1_0 = vec_sub(vecf1_0, zero_point_vec1); + Vectorized vf0_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_0), + vec_mul(scale_vec1, vec_substract_src_zp1_0)}; + + vfloat32 vec_substract_src_zp0_1 = vec_sub(vecf0_1, zero_point_vec0); + vfloat32 vec_substract_src_zp1_1 = vec_sub(vecf1_1, zero_point_vec1); + Vectorized vf1_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_1), + vec_mul(scale_vec1, vec_substract_src_zp1_1)}; + + vfloat32 vec_substract_src_zp0_2 = vec_sub(vecf0_2, zero_point_vec0); + vfloat32 vec_substract_src_zp1_2 = vec_sub(vecf1_2, zero_point_vec1); + Vectorized vf2_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_2), + vec_mul(scale_vec1, vec_substract_src_zp1_2)}; + + vfloat32 vec_substract_src_zp0_3 = vec_sub(vecf0_3, zero_point_vec0); + vfloat32 vec_substract_src_zp1_3 = vec_sub(vecf1_3, zero_point_vec1); + Vectorized vf3_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_3), + vec_mul(scale_vec1, vec_substract_src_zp1_3)}; + + return {vf0_zp, vf1_zp, vf2_zp, vf3_zp}; + } + + float_vec_return_type C10_ALWAYS_INLINE + dequantize(Vectorized scale, Vectorized zero_point) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return { + Vectorized{ + (vecf0_0 - zero_point0) * scale_vec0, + (vecf1_0 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_1 - zero_point0) * scale_vec0, + (vecf1_1 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_2 - zero_point0) * scale_vec0, + (vecf1_2 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_3 - zero_point0) * scale_vec0, + (vecf1_3 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 inverse_scale_v = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vint32 vmin = vec_splats(min_val); + // vint32 vmax = vec_splats(max_val); + + Vectorized vf0 = rhs[0]; + Vectorized vf1 = rhs[1]; + Vectorized vf2 = rhs[2]; + Vectorized vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, inverse_scale_v); + vecf1 = vec_mul(vecf1, inverse_scale_v); + vecf2 = vec_mul(vecf2, inverse_scale_v); + vecf3 = vec_mul(vecf3, inverse_scale_v); + + vecf4 = vec_mul(vecf4, inverse_scale_v); + vecf5 = vec_mul(vecf5, inverse_scale_v); + vecf6 = vec_mul(vecf6, inverse_scale_v); + vecf7 = vec_mul(vecf7, inverse_scale_v); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + // veci0 = vec_min(vmax, vec_max( vmin, vecf0)) ; + // veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ; + // veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ; + // veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ; + + // veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ; + // veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ; + // veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ; + // veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ; + // vec_packs CLAMP already + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vectorized C10_ALWAYS_INLINE + relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized C10_ALWAYS_INLINE + relu6(Vectorized zero_point, Vectorized q_six) const { + vint8 max0 = vec_max(_vec0, zero_point._vec0); + vint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + vint16 vecshi0 = vec_unpackh(_vec0); + vint16 vecBshi0 = vec_unpackh(b._vec0); + vint16 vecshi1 = vec_unpackl(_vec0); + vint16 vecBshi1 = vec_unpackl(b._vec0); + + vint16 vecshi2 = vec_unpackh(_vec1); + vint16 vecBshi2 = vec_unpackh(b._vec1); + vint16 vecshi3 = vec_unpackl(_vec1); + vint16 vecBshi3 = vec_unpackl(b._vec1); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vectorized(veci0 - vecBi0, veci1 - vecBi1), + Vectorized(veci2 - vecBi2, veci3 - vecBi3), + Vectorized(veci4 - vecBi4, veci5 - vecBi5), + Vectorized(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vint8 vec0 = vec_packs(vecshi0, vecshi1); + vint8 vec1 = vec_packs(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /) + DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h new file mode 100644 index 0000000000000000000000000000000000000000..9da6dec9db5e0314d3b70f8b4f0e5d919f02490d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h @@ -0,0 +1,538 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +const vint16 mask_unsigned = vec_splats((short int)0xFF); +template <> +struct Vectorized { + private: + union { + struct { + vuint8 _vec0; + vuint8 _vec1; + }; + struct { + vbool8 _vecb0; + vbool8 _vecb1; + }; + + } __attribute__((__may_alias__)); + + public: + Vectorized() {} + using size_type = int; + static constexpr size_type size() { + return 32; + } + + static constexpr size_t float_num_vecs() { + return 4; + } + static constexpr int int_num_vecs() { + return 4; + } + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + using vec_internal_type = vuint8; + using vec_internal_mask_type = vbool8; + // Broadcast constructor + C10_ALWAYS_INLINE Vectorized(const c10::quint8& val) + : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} + + C10_ALWAYS_INLINE Vectorized(const Vectorized& other) + : _vec0{other._vec0}, _vec1(other._vec1) {} + + C10_ALWAYS_INLINE Vectorized(vuint8 v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {} + C10_ALWAYS_INLINE Vectorized(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {} + + C10_ALWAYS_INLINE const vec_internal_type& vec0() const { + return _vec0; + } + C10_ALWAYS_INLINE const vec_internal_type& vec1() const { + return _vec1; + } + + static C10_ALWAYS_INLINE Vectorized loadu( + const void* ptr, + int count = size()) { + if (count == size()) { + return { + vec_vsx_ld(offset0, reinterpret_cast(ptr)), + vec_vsx_ld(offset16, reinterpret_cast(ptr))}; + } + __at_align__ value_type tmp_values[size()] = {}; + std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); + return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; + } + void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { + if (count == size()) { + vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); + vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); + } else if (count > 0) { + __at_align__ value_type tmp_values[size()]; + vec_vsx_st(_vec0, offset0, tmp_values); + vec_vsx_st(_vec1, offset16, tmp_values); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); + } + } + + public: + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point_vec0 = zero_point.vec0(); + vfloat32 zero_point_vec1 = zero_point.vec1(); + + vfloat32 vec_substract_src_zp0_0 = vec_sub(vecf0_0, zero_point_vec0); + vfloat32 vec_substract_src_zp1_0 = vec_sub(vecf1_0, zero_point_vec1); + Vectorized vf0_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_0), + vec_mul(scale_vec1, vec_substract_src_zp1_0)}; + + vfloat32 vec_substract_src_zp0_1 = vec_sub(vecf0_1, zero_point_vec0); + vfloat32 vec_substract_src_zp1_1 = vec_sub(vecf1_1, zero_point_vec1); + Vectorized vf1_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_1), + vec_mul(scale_vec1, vec_substract_src_zp1_1)}; + + vfloat32 vec_substract_src_zp0_2 = vec_sub(vecf0_2, zero_point_vec0); + vfloat32 vec_substract_src_zp1_2 = vec_sub(vecf1_2, zero_point_vec1); + Vectorized vf2_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_2), + vec_mul(scale_vec1, vec_substract_src_zp1_2)}; + + vfloat32 vec_substract_src_zp0_3 = vec_sub(vecf0_3, zero_point_vec0); + vfloat32 vec_substract_src_zp1_3 = vec_sub(vecf1_3, zero_point_vec1); + Vectorized vf3_zp = { + vec_mul(scale_vec0, vec_substract_src_zp0_3), + vec_mul(scale_vec1, vec_substract_src_zp1_3)}; + + return {vf0_zp, vf1_zp, vf2_zp, vf3_zp}; + } + + float_vec_return_type C10_ALWAYS_INLINE + dequantize(Vectorized scale, Vectorized zero_point) const { + // unpacking unsigned as signed + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + + // signed -> unsigned + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 veci1 = vec_unpackl(vecshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 veci3 = vec_unpackl(vecshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 veci5 = vec_unpackl(vecshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 veci7 = vec_unpackl(vecshi3); + + vfloat32 vecf0_0 = vec_float(veci0); + vfloat32 vecf1_0 = vec_float(veci1); + + vfloat32 vecf0_1 = vec_float(veci2); + vfloat32 vecf1_1 = vec_float(veci3); + + vfloat32 vecf0_2 = vec_float(veci4); + vfloat32 vecf1_2 = vec_float(veci5); + + vfloat32 vecf0_3 = vec_float(veci6); + vfloat32 vecf1_3 = vec_float(veci7); + vfloat32 scale_vec0 = scale.vec0(); + vfloat32 scale_vec1 = scale.vec1(); + + vfloat32 zero_point0 = zero_point.vec0(); + vfloat32 zero_point1 = zero_point.vec1(); + return { + Vectorized{ + (vecf0_0 - zero_point0) * scale_vec0, + (vecf1_0 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_1 - zero_point0) * scale_vec0, + (vecf1_1 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_2 - zero_point0) * scale_vec0, + (vecf1_2 - zero_point1) * scale_vec1}, + Vectorized{ + (vecf0_3 - zero_point0) * scale_vec0, + (vecf1_3 - zero_point1) * scale_vec1}}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + // constexpr int32_t min_val = std::numeric_limits::min(); + // constexpr int32_t max_val = std::numeric_limits::max(); + + vfloat32 vec_inverse = vec_splats(inverse_scale); + vfloat32 vec_zero_point = vec_splats((float)zero_point); + // vuint32 vmin = vec_splats(min_val); + // vuint32 vmax = vec_splats(max_val); + Vectorized vf0 = rhs[0]; + Vectorized vf1 = rhs[1]; + Vectorized vf2 = rhs[2]; + Vectorized vf3 = rhs[3]; + vfloat32 vecf0 = vf0.vec0(); + vfloat32 vecf1 = vf0.vec1(); + vfloat32 vecf2 = vf1.vec0(); + vfloat32 vecf3 = vf1.vec1(); + + vfloat32 vecf4 = vf2.vec0(); + vfloat32 vecf5 = vf2.vec1(); + vfloat32 vecf6 = vf3.vec0(); + vfloat32 vecf7 = vf3.vec1(); + + vecf0 = vec_mul(vecf0, vec_inverse); + vecf1 = vec_mul(vecf1, vec_inverse); + vecf2 = vec_mul(vecf2, vec_inverse); + vecf3 = vec_mul(vecf3, vec_inverse); + + vecf4 = vec_mul(vecf4, vec_inverse); + vecf5 = vec_mul(vecf5, vec_inverse); + vecf6 = vec_mul(vecf6, vec_inverse); + vecf7 = vec_mul(vecf7, vec_inverse); + + vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); + vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); + vecf2 = vec_add(vec_rint(vecf2), vec_zero_point); + vecf3 = vec_add(vec_rint(vecf3), vec_zero_point); + + vecf4 = vec_add(vec_rint(vecf4), vec_zero_point); + vecf5 = vec_add(vec_rint(vecf5), vec_zero_point); + vecf6 = vec_add(vec_rint(vecf6), vec_zero_point); + vecf7 = vec_add(vec_rint(vecf7), vec_zero_point); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + Vectorized C10_ALWAYS_INLINE + relu(Vectorized zero_point) const { + return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; + } + + Vectorized C10_ALWAYS_INLINE relu6( + Vectorized zero_point, + Vectorized q_six) const { + vuint8 max0 = vec_max(_vec0, zero_point._vec0); + vuint8 max1 = vec_max(_vec1, zero_point._vec1); + return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + vint16 vecshi0 = vec_unpackh((vint8)_vec0); + vint16 vecBshi0 = vec_unpackh((vint8)b._vec0); + vint16 vecshi1 = vec_unpackl((vint8)_vec0); + vint16 vecBshi1 = vec_unpackl((vint8)b._vec0); + + vint16 vecshi2 = vec_unpackh((vint8)_vec1); + vint16 vecBshi2 = vec_unpackh((vint8)b._vec1); + vint16 vecshi3 = vec_unpackl((vint8)_vec1); + vint16 vecBshi3 = vec_unpackl((vint8)b._vec1); + + vecshi0 = vec_and(vecshi0, mask_unsigned); + vecBshi0 = vec_and(vecBshi0, mask_unsigned); + vecshi1 = vec_and(vecshi1, mask_unsigned); + vecBshi1 = vec_and(vecBshi1, mask_unsigned); + + vecshi2 = vec_and(vecshi2, mask_unsigned); + vecBshi2 = vec_and(vecBshi2, mask_unsigned); + vecshi3 = vec_and(vecshi3, mask_unsigned); + vecBshi3 = vec_and(vecBshi3, mask_unsigned); + + vint32 veci0 = vec_unpackh(vecshi0); + vint32 vecBi0 = vec_unpackh(vecBshi0); + vint32 veci1 = vec_unpackl(vecshi0); + vint32 vecBi1 = vec_unpackl(vecBshi0); + + vint32 veci2 = vec_unpackh(vecshi1); + vint32 vecBi2 = vec_unpackh(vecBshi1); + vint32 veci3 = vec_unpackl(vecshi1); + vint32 vecBi3 = vec_unpackl(vecBshi1); + + vint32 veci4 = vec_unpackh(vecshi2); + vint32 vecBi4 = vec_unpackh(vecBshi2); + vint32 veci5 = vec_unpackl(vecshi2); + vint32 vecBi5 = vec_unpackl(vecBshi2); + + vint32 veci6 = vec_unpackh(vecshi3); + vint32 vecBi6 = vec_unpackh(vecBshi3); + vint32 veci7 = vec_unpackl(vecshi3); + vint32 vecBi7 = vec_unpackl(vecBshi3); + + return { + Vectorized(veci0 - vecBi0, veci1 - vecBi1), + Vectorized(veci2 - vecBi2, veci3 - vecBi3), + Vectorized(veci4 - vecBi4, veci5 - vecBi5), + Vectorized(veci6 - vecBi6, veci7 - vecBi7)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + vfloat32 vec_multiplier = vec_splats(multiplier); + vint32 vec_zero_point = vec_splats(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + vfloat32 vecf0 = vec_float(vi0.vec0()); + vfloat32 vecf1 = vec_float(vi0.vec1()); + vfloat32 vecf2 = vec_float(vi1.vec0()); + vfloat32 vecf3 = vec_float(vi1.vec1()); + + vfloat32 vecf4 = vec_float(vi2.vec0()); + vfloat32 vecf5 = vec_float(vi2.vec1()); + vfloat32 vecf6 = vec_float(vi3.vec0()); + vfloat32 vecf7 = vec_float(vi3.vec1()); + + vecf0 = vec_mul(vecf0, vec_multiplier); + vecf1 = vec_mul(vecf1, vec_multiplier); + vecf2 = vec_mul(vecf2, vec_multiplier); + vecf3 = vec_mul(vecf3, vec_multiplier); + + vecf4 = vec_mul(vecf4, vec_multiplier); + vecf5 = vec_mul(vecf5, vec_multiplier); + vecf6 = vec_mul(vecf6, vec_multiplier); + vecf7 = vec_mul(vecf7, vec_multiplier); + + vecf0 = vec_rint(vecf0); + vecf1 = vec_rint(vecf1); + vecf2 = vec_rint(vecf2); + vecf3 = vec_rint(vecf3); + + vecf4 = vec_rint(vecf4); + vecf5 = vec_rint(vecf5); + vecf6 = vec_rint(vecf6); + vecf7 = vec_rint(vecf7); + + vint32 veci0 = vec_signed(vecf0); + vint32 veci1 = vec_signed(vecf1); + vint32 veci2 = vec_signed(vecf2); + vint32 veci3 = vec_signed(vecf3); + + vint32 veci4 = vec_signed(vecf4); + vint32 veci5 = vec_signed(vecf5); + vint32 veci6 = vec_signed(vecf6); + vint32 veci7 = vec_signed(vecf7); + + veci0 = vec_add(veci0, vec_zero_point); + veci1 = vec_add(veci1, vec_zero_point); + veci2 = vec_add(veci2, vec_zero_point); + veci3 = vec_add(veci3, vec_zero_point); + + veci4 = vec_add(veci4, vec_zero_point); + veci5 = vec_add(veci5, vec_zero_point); + veci6 = vec_add(veci6, vec_zero_point); + veci7 = vec_add(veci7, vec_zero_point); + + vint16 vecshi0 = vec_packs(veci0, veci1); + vint16 vecshi1 = vec_packs(veci2, veci3); + vint16 vecshi2 = vec_packs(veci4, veci5); + vint16 vecshi3 = vec_packs(veci6, veci7); + + vuint8 vec0 = vec_packsu(vecshi0, vecshi1); + vuint8 vec1 = vec_packsu(vecshi2, vecshi3); + + return {vec0, vec1}; + } + + DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq) + DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne) + DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt) + DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple) + DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt) + DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge) + DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add) + DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub) + DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul) + DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /) + DEFINE_MEMBER_OP(maximum, c10::quint8, vec_max) + DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min) + DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and) + DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or) + DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor) +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return a.minimum(b); +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized{ + vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; +} + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..a25216bd5db17b5a732f7bdb3ebd4047eef1e24f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h @@ -0,0 +1,581 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include + +#if defined(__clang__) +typedef __vector __bool char vbool8; +typedef __vector __bool short vbool16; +typedef __vector __bool int vbool32; +typedef __vector __bool long long vbool64; +using vint8 = __attribute__((vector_size(16))) signed char; +using vint16 = __attribute__((vector_size(16))) signed short; +using vint32 = __attribute__((vector_size(16))) signed int; +using vint64 = __attribute__((vector_size(16))) signed long long; +using vuint8 = __attribute__((vector_size(16))) unsigned char; +using vuint16 = __attribute__((vector_size(16))) unsigned short; +using vuint32 = __attribute__((vector_size(16))) unsigned int; +using vuint64 = __attribute__((vector_size(16))) unsigned long long; +using vfloat32 = __attribute__((vector_size(16))) float; +using vfloat64 = __attribute__((vector_size(16))) double; +#else +using vbool8 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; +using vbool16 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; +using vbool32 = + __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) int; +using vbool64 = __attribute__((altivec(vector__))) +__attribute__((altivec(bool__))) long long; +using vint8 = __attribute__((altivec(vector__))) signed char; +using vint16 = __attribute__((altivec(vector__))) signed short; +using vint32 = __attribute__((altivec(vector__))) signed int; +using vint64 = __attribute__((altivec(vector__))) signed long long; +using vuint8 = __attribute__((altivec(vector__))) unsigned char; +using vuint16 = __attribute__((altivec(vector__))) unsigned short; +using vuint32 = __attribute__((altivec(vector__))) unsigned int; +using vuint64 = __attribute__((altivec(vector__))) unsigned long long; +using vfloat32 = __attribute__((altivec(vector__))) float; +using vfloat64 = __attribute__((altivec(vector__))) double; +#endif + +inline auto make_vuint(vint8 v) { + return reinterpret_cast(v); +} +inline auto make_vuint(vint16 v) { + return reinterpret_cast(v); +} +inline auto make_vuint(vint32 v) { + return reinterpret_cast(v); +} +inline auto make_vuint(vint64 v) { + return reinterpret_cast(v); +} + +#if !defined(vec_float) +C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) { + vfloat32 vec_out; + __asm__("xvcvsxwsp %x0,%x1" : "=wf"(vec_out) : "wa"(vec_in)); + return vec_out; +} +#endif + +#if !defined(vec_signed) +C10_ALWAYS_INLINE vint32 vec_signed(const vfloat32& vec_in) { + vint32 vec_out; + __asm__("xvcvspsxws %x0,%x1" : "=wa"(vec_out) : "wf"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vint64 vec_signed(const vfloat64& vec_in) { + vint64 vec_out; + __asm__("xvcvdpsxds %x0,%x1" : "=wa"(vec_out) : "wd"(vec_in)); + return vec_out; +} +#endif + +#if !defined(vec_neg) +C10_ALWAYS_INLINE vfloat32 vec_neg(const vfloat32& vec_in) { + vfloat32 vec_out; + __asm__("xvnegsp %x0,%x1" : "=wf"(vec_out) : "wf"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vfloat64 vec_neg(const vfloat64& vec_in) { + vfloat64 vec_out; + __asm__("xvnegdp %x0,%x1" : "=wd"(vec_out) : "wd"(vec_in)); + return vec_out; +} + +C10_ALWAYS_INLINE vint16 vec_neg(const vint16& vec_in) { + vint16 vint0 = {0, 0, 0, 0, 0, 0, 0, 0}; + return vec_vsubuhm(vint0, vec_in); +} + +C10_ALWAYS_INLINE vint32 vec_neg(const vint32& vec_in) { + vint32 vint0 = {0, 0, 0, 0}; + return vec_vsubuwm(vint0, vec_in); +} + +C10_ALWAYS_INLINE vint64 vec_neg(const vint64& vec_in) { + return -vec_in; +} +#endif + +#if !defined(vec_sldw) +template +C10_ALWAYS_INLINE vfloat32 +vec_sldw_aux(const vfloat32& vec_in0, const vfloat32& vec_in1) { + vfloat32 vec_out; + __asm("xxsldwi %x0, %x1, %x2, %3 " + : "=wa"(vec_out) + : "wa"(vec_in0), "wa"(vec_in1), "I"(C)); + return vec_out; +} + +#define vec_sldw(a, b, c) vec_sldw_aux(a, b) +#endif + +#define vec_not(a) vec_nor(a, a) +#if defined(__clang__) && !defined(vec_splats) +C10_ALWAYS_INLINE vint64 vec_splats(const int64_t& a) { + return vec_splats(a); +} +#endif +// Vectorized min/max which return a if any operand is nan +template +C10_ALWAYS_INLINE T vec_min_nan(const T& a, const T& b) { + return vec_min(a, b); +} +template +C10_ALWAYS_INLINE T vec_max_nan(const T& a, const T& b) { + return vec_max(a, b); +} + +// Specializations for float/double taken from Eigen +template <> +C10_ALWAYS_INLINE vfloat32 +vec_min_nan(const vfloat32& a, const vfloat32& b) { + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE + // regarding NaN + vfloat32 ret; + __asm__("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} +// Specializations for float/double taken from Eigen +template <> +C10_ALWAYS_INLINE vfloat32 +vec_max_nan(const vfloat32& a, const vfloat32& b) { + // NOTE: about 10% slower than vec_max, but consistent with std::min and SSE + // regarding NaN + vfloat32 ret; + __asm__("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} + +template <> +C10_ALWAYS_INLINE vfloat64 +vec_min_nan(const vfloat64& a, const vfloat64& b) { + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE + // regarding NaN + vfloat64 ret; + __asm__("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} +template <> +C10_ALWAYS_INLINE vfloat64 +vec_max_nan(const vfloat64& a, const vfloat64& b) { + // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE + // regarding NaN + vfloat64 ret; + __asm__("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" + : "=&wa"(ret) + : "wa"(a), "wa"(b)); + return ret; +} + +// Vectorizes min/max function which returns nan if any side is nan +#define C10_VSX_VEC_NAN_PROPAG(name, type, btype, func) \ + C10_ALWAYS_INLINE type name(const type& a, const type& b) { \ + type tmp = func(a, b); \ + btype nan_a = vec_cmpne(a, a); \ + btype nan_b = vec_cmpne(b, b); \ + tmp = vec_sel(tmp, a, nan_a); \ + return vec_sel(tmp, b, nan_b); \ + } + +C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat32, vbool32, vec_min) +C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat32, vbool32, vec_max) +C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat64, vbool64, vec_min) +C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat64, vbool64, vec_max) + +#undef C10_VSX_VEC_NAN_PROPAG + +#define DEFINE_MEMBER_UNARY_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op() const { \ + return Vectorized{func(_vec0), func(_vec1)}; \ + } + +#define DEFINE_MEMBER_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + return Vectorized{ \ + func(_vec0, other._vec0), func(_vec1, other._vec1)}; \ + } + +#define DEFINE_MEMBER_BITWISE_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + return Vectorized{ \ + func(_vecb0, other._vecb0), func(_vecb1, other._vecb1)}; \ + } + +#define DEFINE_MEMBER_TERNARY_OP(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op( \ + const Vectorized& b, const Vectorized& c) const { \ + return Vectorized{ \ + func(_vec0, b._vec0, c._vec0), func(_vec1, b._vec1, c._vec1)}; \ + } + +#define DEFINE_MEMBER_EMULATE_BINARY_OP(op, op_type, binary_op) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& b) \ + const { \ + Vectorized::vec_internal_type ret_0; \ + Vectorized::vec_internal_type ret_1; \ + for (int i = 0; i < Vectorized::size() / 2; i++) { \ + ret_0[i] = _vec0[i] binary_op b._vec0[i]; \ + ret_1[i] = _vec1[i] binary_op b._vec1[i]; \ + } \ + return Vectorized{ret_0, ret_1}; \ + } + +#define DEFINE_MEMBER_OP_AND_ONE(op, op_type, func) \ + Vectorized C10_ALWAYS_INLINE op(const Vectorized& other) \ + const { \ + using vvtype = Vectorized::vec_internal_type; \ + const vvtype v_one = vec_splats(static_cast(1.0)); \ + vvtype ret0 = (vvtype)func(_vec0, other._vec0); \ + vvtype ret1 = (vvtype)func(_vec1, other._vec1); \ + return Vectorized{vec_and(ret0, v_one), vec_and(ret1, v_one)}; \ + } + +#define DEFINE_CLAMP_FUNCS(operand_type) \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return Vectorized{ \ + vec_min_nan(vec_max_nan(a.vec0(), min.vec0()), max.vec0()), \ + vec_min_nan(vec_max_nan(a.vec1(), min.vec1()), max.vec1())}; \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_min( \ + const Vectorized& a, \ + const Vectorized& min) { \ + return Vectorized{ \ + vec_max_nan(a.vec0(), min.vec0()), vec_max_nan(a.vec1(), min.vec1())}; \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_max( \ + const Vectorized& a, \ + const Vectorized& max) { \ + return Vectorized{ \ + vec_min_nan(a.vec0(), max.vec0()), vec_min_nan(a.vec1(), max.vec1())}; \ + } + +#define DEFINE_REINTERPRET_CAST_FUNCS( \ + first_type, cast_type, cast_inner_vector_type) \ + template <> \ + C10_ALWAYS_INLINE Vectorized cast( \ + const Vectorized& src) { \ + return Vectorized{ \ + (cast_inner_vector_type)src.vec0(), \ + (cast_inner_vector_type)src.vec1()}; \ + } + +#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(first_type) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, double, vfloat64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, float, vfloat32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int64_t, vint64) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int32_t, vint32) \ + DEFINE_REINTERPRET_CAST_FUNCS(first_type, int16_t, vint16) + +// it can be used to emulate blend faster +constexpr int blendChoice( + uint32_t mask, + uint32_t half1 = 0xF, + uint32_t half2 = 0xF0) { + uint32_t none = 0; + uint32_t both = half1 | half2; + // clamp it between 0 and both + mask = mask & both; + // return (a._vec0, a._vec1) + if (mask == none) + return 0; + // return (b._vec0,b._vec1) + else if (mask == both) + return 1; + // return (b._vec0,a._vec1) + else if (mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (mask > 0 && mask < half1) + return 4; + // return (*_vec0,b._vec1) + else if ((mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((mask & half1) == 0 && mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((mask & half1) == half1 && mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +constexpr int blendChoiceDbl(uint32_t mask) { + // clamp it 0 and 0xF + return blendChoice(mask, 0x3, 0xC); +} + +constexpr vbool32 VsxMask1(uint32_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (vbool32){g0, g1, g2, g3}; +} + +constexpr vbool32 VsxMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xFF) >> 4; + return VsxMask1(mask2); +} + +constexpr vbool64 VsxDblMask1(uint32_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (vbool64){g0, g1}; +} + +constexpr vbool64 VsxDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +constexpr int maskForComplex(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + if (mask & 4) + complex_mask |= (3 << 4); + if (mask & 8) + complex_mask |= (3 << 6); + return complex_mask; +} + +constexpr int maskForComplexDbl(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + return complex_mask; +} + +constexpr int blendChoiceComplex(uint32_t mask) { + return blendChoice(maskForComplex(mask)); +} + +constexpr int blendChoiceComplexDbl(uint32_t mask) { + return blendChoiceDbl(maskForComplexDbl(mask)); +} + +constexpr vbool32 VsxComplexMask1(uint32_t mask) { + return VsxMask1(maskForComplex(mask)); +} + +constexpr vbool32 VsxComplexMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxMask1(maskForComplex(mask2)); +} + +constexpr vbool64 VsxComplexDblMask1(uint32_t mask) { + return VsxDblMask1(mask); +} + +constexpr vbool64 VsxComplexDblMask2(uint32_t mask) { + uint32_t mask2 = (mask & 0xF) >> 2; + return VsxDblMask1(mask2); +} + +// constants +namespace at { +namespace vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +// +constexpr int offset0 = 0; +constexpr int offset16 = 16; + +// #Constants +const vuint8 mask_zero_bits = vuint8{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 96, + 64, + 32, + 0}; + +const vuint8 swap_mask = + vuint8{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; + +const vint32 v0x7f = vec_splats(0x7f); +const vint32 vi_0 = vec_splats((int)(0)); +const vint32 vi_1 = vec_splats((int)1); +const vint32 vi_2 = vec_splats((int)2); +const vint32 vi_4 = vec_splats((int)4); +const vint32 vi_inv1 = vec_splats((int)~1); +const vuint32 vu_29 = vec_splats(29u); +const vuint32 vu_23 = vec_splats(23u); + +const vbool32 inv_mant_mask = (vbool32)vec_splats((unsigned int)~0xff800000); +const vbool32 sign_mask = (vbool32)vec_splats((int)0x80000000); +const vbool32 real_mask = vbool32{0xFFFFFFFF, 0x0, 0xFFFFFFFF, 0x0}; +const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF}; +const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000}; +const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0}; + +const vbool64 vd_sign_mask = vbool64{0x8000000000000000, 0x8000000000000000}; +const vbool64 vd_imag_mask = vbool64{0x0, 0xFFFFFFFFFFFFFFFF}; +const vbool64 vd_real_mask = vbool64{0xFFFFFFFFFFFFFFFF, 0x0}; +const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000}; +const vbool64 vd_rsign_mask = vbool64{0x8000000000000000, 0x0}; + +const vfloat32 zero = vec_splats(0.f); +const vfloat32 half = vec_splats(0.5f); +const vfloat32 one = vec_splats(1.f); +const vfloat32 two = vec_splats(2.0f); +const vfloat32 _4div_pi = vec_splats(1.27323954473516f); +const vfloat32 v_inf = (vfloat32)vec_splats(0x7f800000u); +const vfloat32 v_minus_inf = + vfloat32{0xff800000u, 0xff800000u, 0xff800000u, 0xff800000u}; +const vfloat32 v_nan = (vfloat32)vec_splats(0x7fffffff); +const vfloat32 log10e_inv = vec_splats(0.43429448190325176f); +const vfloat32 log2e_inv = vec_splats(1.4426950408889634f); +const vfloat32 log2eB_inv = vec_splats(1.442695036924675f); +const vfloat32 cephes_SQRTHF = vec_splats(0.707106781186547524f); +const vfloat32 coscof_p0 = vec_splats(2.443315711809948E-005f); +const vfloat32 coscof_p1 = vec_splats(-1.388731625493765E-003f); +const vfloat32 coscof_p2 = vec_splats(4.166664568298827E-002f); +const vfloat32 exp_hi = vec_splats(104.f); +const vfloat32 exp_lo = vec_splats(-104.f); +const vfloat32 exp_p0 = vec_splats(0.000198527617612853646278381f); +const vfloat32 exp_p1 = vec_splats((0.00139304355252534151077271f)); +const vfloat32 exp_p2 = vec_splats(0.00833336077630519866943359f); +const vfloat32 exp_p3 = vec_splats(0.0416664853692054748535156f); +const vfloat32 exp_p4 = vec_splats(0.166666671633720397949219f); +const vfloat32 exp_p5 = vec_splats(0.5f); +const vfloat32 log_p0 = vec_splats(7.0376836292E-2f); +const vfloat32 log_p1 = vec_splats(-1.1514610310E-1f); +const vfloat32 log_p2 = vec_splats(1.1676998740E-1f); +const vfloat32 log_p3 = vec_splats(-1.2420140846E-1f); +const vfloat32 log_p4 = vec_splats(+1.4249322787E-1f); +const vfloat32 log_p5 = vec_splats(-1.6668057665E-1f); +const vfloat32 log_p6 = vec_splats(+2.0000714765E-1f); +const vfloat32 log_p7 = vec_splats(-2.4999993993E-1f); +const vfloat32 log_p8 = vec_splats(+3.3333331174E-1f); +const vfloat32 log_q1 = vec_splats(-2.12194440e-4f); +const vfloat32 log_q2 = vec_splats(0.693359375f); +const vfloat32 max_logf = vec_splats(88.02969187150841f); +const vfloat32 max_numf = + vec_splats(1.7014117331926442990585209174225846272e38f); +const vfloat32 min_inf = (vfloat32)vec_splats(0xff800000u); +const vfloat32 min_norm_pos = (vfloat32)vec_splats(0x0800000u); +const vfloat32 minus_cephes_dp1 = vec_splats(-0.78515625f); +const vfloat32 minus_cephes_dp2 = vec_splats(-2.4187564849853515625e-4f); +const vfloat32 minus_cephes_dp3 = vec_splats(-3.77489497744594108e-8f); +const vfloat32 negln2f_hi = vec_splats(-0.693145751953125f); +const vfloat32 negln2f_lo = vec_splats(-1.428606765330187045e-06f); +const vfloat32 p0 = vec_splats(2.03721912945E-4f); +const vfloat32 p1 = vec_splats(8.33028376239E-3f); +const vfloat32 p2 = vec_splats(1.66667160211E-1f); +const vfloat32 sincof_p0 = vec_splats(-1.9515295891E-4f); +const vfloat32 sincof_p1 = vec_splats(8.3321608736E-3f); +const vfloat32 sincof_p2 = vec_splats(-1.6666654611E-1f); +const vfloat32 tanh_0p625 = vec_splats(0.625f); +const vfloat32 tanh_half_max = vec_splats(44.014845935754205f); +const vfloat32 tanh_p0 = vec_splats(-5.70498872745E-3f); +const vfloat32 tanh_p1 = vec_splats(2.06390887954E-2f); +const vfloat32 tanh_p2 = vec_splats(-5.37397155531E-2f); +const vfloat32 tanh_p3 = vec_splats(1.33314422036E-1f); +const vfloat32 tanh_p4 = vec_splats(-3.33332819422E-1f); +const vfloat32 vcheck = vec_splats((float)(1LL << 24)); +const vfloat32 imag_one = vfloat32{0.f, 1.f, 0.f, 1.f}; +const vfloat32 imag_half = vfloat32{0.f, 0.5f, 0.f, 0.5f}; +const vfloat32 sqrt2_2 = vfloat32{ + 0.70710676908493042f, + 0.70710676908493042, + 0.70710676908493042, + 0.70710676908493042}; +const vfloat32 pi_2 = vfloat32{M_PI / 2, 0.0, M_PI / 2, 0.0}; +const vfloat32 vf_89 = vfloat32{89.f, 89.f, 89.f, 89.f}; +const vfloat64 vd_one = vec_splats(1.0); +const vfloat64 vd_zero = vec_splats(0.0); +const vfloat64 vd_log10e_inv = vec_splats(0.43429448190325176); +const vfloat64 vd_log2e_inv = vec_splats(1.4426950408889634); +const vfloat64 vd_imag_one = vfloat64{0.0, 1.0}; +const vfloat64 vd_imag_half = vfloat64{0.0, 0.5}; +const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757}; +const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0}; + +template +Vectorized VsxShiftRightArith( + const Vectorized& a, + const Vectorized& b) { + const Vectorized max_shift(sizeof(T) * CHAR_BIT - std::is_signed_v); + const auto mask = (b < Vectorized(0)) | (b >= max_shift); + const auto shift = Vectorized::blendv(b, max_shift, mask); + return Vectorized{ + vec_sra(a.vec0(), make_vuint(shift.vec0())), + vec_sra(a.vec1(), make_vuint(shift.vec1()))}; +} + +template +Vectorized VsxShiftLeftArith( + const Vectorized& a, + const Vectorized& b) { + const Vectorized max_shift(sizeof(T) * CHAR_BIT); + const auto mask = (b < Vectorized(0)) | (b >= max_shift); + Vectorized ret( + vec_sl(a.vec0(), make_vuint(b.vec0())), + vec_sl(a.vec1(), make_vuint(b.vec1()))); + return Vectorized::blendv(ret, Vectorized(0), mask); +} + +#define DEFINE_SHIFT_FUNCS(operand_type) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator>>( \ + const Vectorized& a, const Vectorized& b) { \ + return VsxShiftRightArith(a, b); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator<<( \ + const Vectorized& a, const Vectorized& b) { \ + return VsxShiftLeftArith(a, b); \ + } + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h new file mode 100644 index 0000000000000000000000000000000000000000..c48ae8c5732d8276a45ac698dedf87f27678d582 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -0,0 +1,2978 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include +#include +#include +#if defined(__clang__) +#include +#elif defined(__GNUC__) || defined(__GNUG__) +#include +#include +#endif +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +template +constexpr bool is_zarch_implemented() { + return ( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v); +} + +template +constexpr bool is_zarch_implemented_quant() { + return ( + std::is_same_v || std::is_same_v || + std::is_same_v); +} + +template +constexpr bool is_zarch_implemented_complex() { + return std::is_same_v> || + std::is_same_v>; +} + +constexpr int offset0 = 0; +constexpr int offset16 = 16; + +template +struct VecBinaryType { + using type __attribute__((vector_size(16))) = uintmax_t; +}; + +template <> +struct VecBinaryType<8> { + using type = __attribute__((vector_size(16))) unsigned long long; +}; + +template <> +struct VecBinaryType<4> { + using type = __attribute__((vector_size(16))) unsigned int; +}; + +template <> +struct VecBinaryType<2> { + using type = __attribute__((vector_size(16))) unsigned short; +}; + +template <> +struct VecBinaryType<1> { + using type = __attribute__((vector_size(16))) unsigned char; +}; + +template +struct VecInnerType { + using Type __attribute__((vector_size(16))) = T; + using BinaryType = typename VecBinaryType::type; + using ElementType = T; + static constexpr int size = 16 / sizeof(T); +}; + +// define for int64_t properly for load +template <> +struct VecInnerType { + using Type = __attribute__((vector_size(16))) signed long long; + using ElementType = signed long long; + using BinaryType = typename VecBinaryType::type; + static constexpr int size = 16 / sizeof(signed long long); +}; + +template +using ZSimdVect = typename VecInnerType::Type; +template +using ZSimdVectBinary = typename VecInnerType::BinaryType; +template +using ZSimdVectElement = typename VecInnerType::ElementType; + +constexpr int blendChoiceInner( + const uint64_t mask, + const uint64_t half1 = 0xF, + const uint64_t half2 = 0xF0) { + uint64_t none = 0; + uint64_t both = half1 | half2; + // clamp it between 0 and both + auto res_mask = mask & both; + // return (a._vec0, a._vec1) + if (res_mask == none) + return 0; + // return (b._vec0,b._vec1) + else if (res_mask == both) + return 1; + // return (b._vec0, a._vec1) + else if (res_mask == half1) + return 2; + // return (a._vec0,b._vec1) + else if (res_mask == half2) + return 3; + // return (*_vec0,a._vec1) + else if (res_mask > 0 && res_mask < half1) + return 4; + // return (*_vec0,b._vec1) + else if ((res_mask & half2) == half2) + return 5; + // return (a._vec0,*_vec1) + else if ((res_mask & half1) == 0 && res_mask > half1) + return 6; + // return (b._vec0,*_vec1) + else if ((res_mask & half1) == half1 && res_mask > half1) + return 7; + // return (*_vec0,*_vec1) + return 8; +} + +// it can be used to emulate blend faster +template +constexpr int blendChoice(const uint64_t mask) { + static_assert(Z < 1 || Z > 8, "not implemented"); + return blendChoiceInner(mask); +} + +template <> +constexpr int blendChoice<1>(const uint64_t mask) { + return blendChoiceInner(mask, 0x0000FFFF, 0xFFFF0000); +} + +template <> +constexpr int blendChoice<2>(const uint64_t mask) { + return blendChoiceInner(mask, 0x00FF, 0xFF00); +} + +template <> +constexpr int blendChoice<4>(const uint64_t mask) { + return blendChoiceInner(mask, 0xF, 0xF0); +} + +template <> +constexpr int blendChoice<8>(const uint64_t mask) { + // clamp it 0 and 0xF + return blendChoiceInner(mask, 0x3, 0xC); +} + +template +constexpr auto GetMask1(const uint64_t mask) { + return typename VecBinaryType::type{}; +} + +template +constexpr auto GetMask2(const uint64_t mask) { + return typename VecBinaryType::type{}; +} + +template <> +constexpr auto GetMask1<1>(const uint64_t mask) { + constexpr uint8_t t = (int)0xFF; + uint8_t g0 = (mask & 1) * t; + uint8_t g1 = ((mask & 2) >> 1) * t; + uint8_t g2 = ((mask & 4) >> 2) * t; + uint8_t g3 = ((mask & 8) >> 3) * t; + uint8_t g4 = ((mask & 16) >> 4) * t; + uint8_t g5 = ((mask & 32) >> 5) * t; + uint8_t g6 = ((mask & 64) >> 6) * t; + uint8_t g7 = ((mask & 128) >> 7) * t; + uint8_t g8 = ((mask & 256) >> 8) * t; + uint8_t g9 = ((mask & 512) >> 9) * t; + uint8_t g10 = ((mask & 1024) >> 10) * t; + uint8_t g11 = ((mask & 2048) >> 11) * t; + uint8_t g12 = ((mask & 4096) >> 12) * t; + uint8_t g13 = ((mask & 8192) >> 13) * t; + uint8_t g14 = ((mask & 16384) >> 14) * t; + uint8_t g15 = ((mask & 32768) >> 15) * t; + return (typename VecBinaryType<1>::type){ + g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15}; +} + +template <> +constexpr auto GetMask2<1>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFFFFFFFF) >> 16; + return GetMask1<1>(mask2); +} + +template <> +constexpr auto GetMask1<2>(const uint64_t mask) { + constexpr uint16_t t = (int)0xFFFF; + uint16_t g0 = (mask & 1) * t; + uint16_t g1 = ((mask & 2) >> 1) * t; + uint16_t g2 = ((mask & 4) >> 2) * t; + uint16_t g3 = ((mask & 8) >> 3) * t; + uint16_t g4 = ((mask & 16) >> 4) * t; + uint16_t g5 = ((mask & 32) >> 5) * t; + uint16_t g6 = ((mask & 64) >> 6) * t; + uint16_t g7 = ((mask & 128) >> 7) * t; + return (typename VecBinaryType<2>::type){g0, g1, g2, g3, g4, g5, g6, g7}; +} + +template <> +constexpr auto GetMask2<2>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFFFF) >> 8; + return GetMask1<2>(mask2); +} + +template <> +constexpr auto GetMask1<4>(const uint64_t mask) { + uint32_t g0 = (mask & 1) * 0xffffffff; + uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff; + uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff; + uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff; + return (typename VecBinaryType<4>::type){g0, g1, g2, g3}; +} + +template <> +constexpr auto GetMask2<4>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xFF) >> 4; + return GetMask1<4>(mask2); +} + +template <> +constexpr auto GetMask1<8>(const uint64_t mask) { + uint64_t g0 = (mask & 1) * 0xffffffffffffffff; + uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff; + return (typename VecBinaryType<8>::type){g0, g1}; +} + +template <> +constexpr auto GetMask2<8>(const uint64_t mask) { + uint64_t mask2 = (mask & 0xF) >> 2; + return GetMask1<8>(mask2); +} + +template +constexpr int maskForComplex(uint32_t mask) { + return 0; +} + +template <> +constexpr int maskForComplex<8>(uint32_t mask) { + mask = mask & 0xF; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + if (mask & 4) + complex_mask |= (3 << 4); + if (mask & 8) + complex_mask |= (3 << 6); + return complex_mask; +} + +template <> +constexpr int maskForComplex<16>(uint32_t mask) { + mask = mask & 0x3; + int complex_mask = 0; + if (mask & 1) + complex_mask |= 3; + if (mask & 2) + complex_mask |= (3 << 2); + return complex_mask; +} + +template > +constexpr int blend_choice() { + return 0xAA; +} + +template <> +constexpr int blend_choice>() { + return 0x0A; +} + +constexpr int64_t allbitset(int16_t x) { + int64_t onex = 1; + return (onex << x) - onex; +} + +namespace { /* unnamed namespace */ + +ZSimdVect vec_mergee(ZSimdVect x, ZSimdVect y) { + constexpr ZSimdVectBinary mergee_mask{ + 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27}; + return vec_perm(x, y, mergee_mask); +} + +ZSimdVect vec_mergee(ZSimdVect x, ZSimdVect y) { + return vec_mergeh(x, y); +} + +ZSimdVect vec_mergeo(ZSimdVect x, ZSimdVect y) { + constexpr ZSimdVectBinary mergeo_mask{ + 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31}; + return vec_perm(x, y, mergeo_mask); +} + +ZSimdVect vec_mergeo(ZSimdVect x, ZSimdVect y) { + return vec_mergel(x, y); +} + +} /* unnamed namespace */ + +// +template +constexpr auto GetBpermZeroMask() { + return ZSimdVectBinary{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 96, + 64, + 32, + 0}; +} + +template <> +constexpr auto GetBpermZeroMask() { + return ZSimdVectBinary{ + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 128, + 64, + 0}; +} + +constexpr auto GetSwapMaskFloat() { + return ZSimdVectBinary{ + 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11}; +} + +template +struct is_vec_specialized_for()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using value_type = T; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using size_type = int; + // because of gcc inconsistency for int64_t we are obliged to use this, not + // value_type + using ElementType = ZSimdVectElement; + using vinner_data = std::pair; + + private: + vtype _vec0; + vtype _vec1; + + public: + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(ElementType); + } + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(vtype v) : _vec0{v}, _vec1{v} {} + C10_ALWAYS_INLINE Vectorized(const vinner_data& v) + : _vec0{v.first}, _vec1{v.second} {} + C10_ALWAYS_INLINE Vectorized(vtype v1, vtype v2) : _vec0{v1}, _vec1{v2} {} + C10_ALWAYS_INLINE Vectorized(T s) + : _vec0{vec_splats((ElementType)s)}, _vec1{vec_splats((ElementType)s)} {} + + template + struct LoaduHelper { + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + __at_align__ ElementType tmp_values[size()] = {}; + std::memcpy( + tmp_values, ptr, std::min(count, size()) * sizeof(ElementType)); + + return { + vec_xl(offset0, &(tmp_values[0])), + vec_xl(offset16, &(tmp_values[0]))}; + } + }; + + template + struct LoaduHelper { + static Vectorized C10_ALWAYS_INLINE + loadu(const ElementType* ptr, int count = size()) { + if (count == size()) { + return {vec_xl(offset0, ptr), vec_xl(offset16, ptr)}; + } + + __at_align__ ElementType tmp_values[size()] = {}; + std::memcpy( + tmp_values, ptr, std::min(count, size()) * sizeof(ElementType)); + + return { + vec_xl(offset0, &(tmp_values[0])), + vec_xl(offset16, &(tmp_values[0]))}; + } + }; + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return LoaduHelper::loadu(ptr, count); + } + + template + static Vectorized C10_ALWAYS_INLINE loadu_one_fourth(const U* ptr) { + // load only first 8 bytes + // only intended to be used with uint8_t + return loadu(ptr, 8 / sizeof(ElementType)); + } + + template + struct StoreHelper { + static void C10_ALWAYS_INLINE + store(const Vectorized& vec, U* ptr, int count = size()) { + if (count > 0) { + __at_align__ ElementType tmp_values[size()]; + vec_xst(vec._vec0, offset0, &(tmp_values[0])); + vec_xst(vec._vec1, offset16, &(tmp_values[0])); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(ElementType)); + } + } + }; + + template + struct StoreHelper { + static void C10_ALWAYS_INLINE + store(const Vectorized& vec, ElementType* ptr, int count = size()) { + if (count == size()) { + vec_xst(vec._vec0, offset0, ptr); + vec_xst(vec._vec1, offset16, ptr); + } else if (count > 0) { + __at_align__ ElementType tmp_values[size()]; + vec_xst(vec._vec0, offset0, &(tmp_values[0])); + vec_xst(vec._vec1, offset16, &(tmp_values[0])); + std::memcpy( + ptr, tmp_values, std::min(count, size()) * sizeof(ElementType)); + } + } + }; + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + return StoreHelper::store(*this, ptr, count); + } + + C10_ALWAYS_INLINE const vtype& vec0() const { + return _vec0; + } + + C10_ALWAYS_INLINE const vtype& vec1() const { + return _vec1; + } + + C10_ALWAYS_INLINE vinner_data data() const { + return std::make_pair<>(_vec0, _vec1); + } + + C10_ALWAYS_INLINE operator vinner_data() const { + return data(); + } + + C10_ALWAYS_INLINE const vmaskType vecb0() const { + return (vmaskType)_vec0; + } + C10_ALWAYS_INLINE const vmaskType vecb1() const { + return (vmaskType)_vec1; + } + + static Vectorized C10_ALWAYS_INLINE blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + return { + vec_sel(a._vec0, b._vec0, mask.vecb0()), + vec_sel(a._vec1, b._vec1, mask.vecb1())}; + } + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4) + : _vec0{s1, s2}, _vec1{s3, s4} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4, T s5, T s6, T s7, T s8) + : _vec0{s1, s2, s3, s4}, _vec1{s5, s6, s7, s8} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized( + T s1, + T s2, + T s3, + T s4, + T s5, + T s6, + T s7, + T s8, + T s9, + T s10, + T s11, + T s12, + T s13, + T s14, + T s15, + T s16) + : _vec0{s1, s2, s3, s4, s5, s6, s7, s8}, + _vec1{s9, s10, s11, s12, s13, s14, s15, s16} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized( + T s1, + T s2, + T s3, + T s4, + T s5, + T s6, + T s7, + T s8, + T s9, + T s10, + T s11, + T s12, + T s13, + T s14, + T s15, + T s16, + T s17, + T s18, + T s19, + T s20, + T s21, + T s22, + T s23, + T s24, + T s25, + T s26, + T s27, + T s28, + T s29, + T s30, + T s31, + T s32) + : _vec0{s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16}, + _vec1{ + s17, + s18, + s19, + s20, + s21, + s22, + s23, + s24, + s25, + s26, + s27, + s28, + s29, + s30, + s31, + s32} {} + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized(base, base + step, base + 2 * step, base + 3 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + + // blend section + template + static std::enable_if_t(mask) == 0, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return a; + } + + template + static std::enable_if_t(mask) == 1, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return b; + } + + template + static std::enable_if_t(mask) == 2, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return {b._vec0, a._vec1}; + } + + template + static std::enable_if_t(mask) == 3, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + return {a._vec0, b._vec1}; + } + + template + static std::enable_if_t(mask) == 4, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1}; + } + + template + static std::enable_if_t(mask) == 5, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1}; + } + + template + static std::enable_if_t(mask) == 6, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_2nd = GetMask2(mask); + // generated masks + return {a._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t(mask) == 7, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_2nd = GetMask2(mask); + // generated masks + return {b._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static std::enable_if_t(mask) == 8, Vectorized> + C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) { + const vmaskType mask_1st = GetMask1(mask); + const vmaskType mask_2nd = GetMask2(mask); + return { + (vtype)vec_sel(a._vec0, b._vec0, mask_1st), + (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)}; + } + + template + static inline std::enable_if_t<(Z >= C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + return b; + } + + template + static inline std::enable_if_t<(Z < C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + if (count == Z) + return blend(a, b); + else + return set_inner(a, b, count); + } + + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + if (count == 0) + return a; + return set_inner<1, size()>(a, b, count); + } + + const ElementType& operator[](int idx) const = delete; + ElementType& operator[](int idx) = delete; + + Vectorized _not() const { + return {(vtype)vec_nor(vecb0(), vecb0()), (vtype)vec_nor(vecb1(), vecb1())}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + return (*this == other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + return (*this != other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const { + return (*this > other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const { + return (*this >= other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const { + return (*this < other) & Vectorized((T)1.0); + } + Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const { + return (*this <= other) & Vectorized((T)1.0); + } + + template , int> = 0> + Vectorized C10_ALWAYS_INLINE abs() const { + return {vec_abs(_vec0), vec_abs(_vec1)}; + } + + template , int> = 0> + Vectorized C10_ALWAYS_INLINE abs() const { + return {_vec0, _vec1}; + } + + Vectorized C10_ALWAYS_INLINE neg() const { + return {-_vec0, -_vec1}; + } + + Vectorized isnan() const { + auto x = *this; + auto ret = (x == x); + return ret._not(); + } + + bool has_inf_nan() const { + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec0[i]) || _isinf(_vec0[i])) { + return true; + } + } + for (const auto i : c10::irange(size() / 2)) { + if (_isnan(_vec1[i]) || _isinf(_vec1[i])) { + return true; + } + } + return false; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized angle() const { + auto tmp = blendv( + Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0)); + return blendv(tmp, *this, isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized angle() const { + return blendv( + Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0)); + } + + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized{0}; + } + Vectorized conj() const { + return *this; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + int zero_mask() const { + auto cmp = (*this == Vectorized(0)); + constexpr auto mask_zero_bits = GetBpermZeroMask(); + ZSimdVectBinary result0 = + vec_bperm_u128((ZSimdVectBinary)cmp.vecb0(), mask_zero_bits); + ZSimdVectBinary result1 = + vec_bperm_u128((ZSimdVectBinary)cmp.vecb1(), mask_zero_bits); + return (result0[0] | (result1[0] << (size() / 2))); + } + + Vectorized C10_ALWAYS_INLINE floor() const { + return {vec_floor(_vec0), vec_floor(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE ceil() const { + return {vec_ceil(_vec0), vec_ceil(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE round() const { + return {vec_round(_vec0), vec_round(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE rint() const { + return {vec_rint(_vec0), vec_rint(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE trunc() const { + return {vec_trunc(_vec0), vec_trunc(_vec1)}; + } + + Vectorized C10_ALWAYS_INLINE frac() const { + return *this - trunc(); + } + + Vectorized C10_ALWAYS_INLINE sqrt() const { + return {vec_sqrt(_vec0), vec_sqrt(_vec1)}; + } + Vectorized C10_ALWAYS_INLINE reciprocal() const { + return Vectorized((T)1) / (*this); + } + Vectorized C10_ALWAYS_INLINE rsqrt() const { + return sqrt().reciprocal(); + } + + template , int> = 0> + inline Vectorized mapOrdinary(float (*const f)(float)) const { + float a00 = f(_vec0[0]); + float a01 = f(_vec0[1]); + float a02 = f(_vec0[2]); + float a03 = f(_vec0[3]); + float a10 = f(_vec1[0]); + float a11 = f(_vec1[1]); + float a12 = f(_vec1[2]); + float a13 = f(_vec1[3]); + return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapOrdinary(double (*const f)(double)) const { + return Vectorized(f(_vec0[0]), f(_vec0[1]), f(_vec1[0]), f(_vec1[1])); + } + + template , int> = 0> + inline Vectorized mapOrdinary( + float (*const f)(float, float), + const Vectorized& b) const { + float a00 = f(_vec0[0], b._vec0[0]); + float a01 = f(_vec0[1], b._vec0[1]); + float a02 = f(_vec0[2], b._vec0[2]); + float a03 = f(_vec0[3], b._vec0[3]); + float a10 = f(_vec1[0], b._vec1[0]); + float a11 = f(_vec1[1], b._vec1[1]); + float a12 = f(_vec1[2], b._vec1[2]); + float a13 = f(_vec1[3], b._vec1[3]); + return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapOrdinary( + double (*const f)(double, double), + const Vectorized& b) const { + return Vectorized( + f(_vec0[0], b._vec0[0]), + f(_vec0[1], b._vec0[1]), + f(_vec1[0], b._vec1[0]), + f(_vec1[1], b._vec1[1])); + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d) const { + vtype a0 = f(_vec0); + vtype a1 = f(_vec1); + return Vectorized{a0, a1}; + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d) const { + return Vectorized(d(_vec0), d(_vec1)); + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d, const Vectorized& b) + const { + vtype a0 = f(_vec0, b._vec0); + vtype a1 = f(_vec1, b._vec1); + return Vectorized{a0, a1}; + } + + template < + typename FloatOp, + typename DoubleOp, + typename U = T, + std::enable_if_t, int> = 0> + inline Vectorized mapSleef(FloatOp f, DoubleOp d, const Vectorized& b) + const { + return Vectorized(d(_vec0, b._vec0), d(_vec1, b._vec1)); + } + + Vectorized acos() const { + return mapSleef(Sleef_acosf4_u10, Sleef_acosd2_u10); + } + Vectorized asin() const { + return mapSleef(Sleef_asinf4_u10, Sleef_asind2_u10); + } + Vectorized atan() const { + return mapSleef(Sleef_atanf4_u10, Sleef_atand2_u10); + } + Vectorized atanh() const { + return mapSleef(Sleef_atanhf4_u10, Sleef_atanhd2_u10); + } + + Vectorized erf() const { + return mapSleef(Sleef_erff4_u10, Sleef_erfd2_u10); + } + Vectorized erfc() const { + return mapSleef(Sleef_erfcf4_u15, Sleef_erfcd2_u15); + } + + Vectorized exp() const { + return mapSleef(Sleef_expf4_u10, Sleef_expd2_u10); + } + Vectorized exp2() const { + return mapSleef(Sleef_exp2f4_u10, Sleef_exp2d2_u10); + } + Vectorized expm1() const { + return mapSleef(Sleef_expm1f4_u10, Sleef_expm1d2_u10); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fexp_u20() const { + return exp(); + } + + Vectorized log() const { + return mapSleef(Sleef_logf4_u10, Sleef_logd2_u10); + } + Vectorized log2() const { + return mapSleef(Sleef_log2f4_u10, Sleef_log2d2_u10); + } + Vectorized log10() const { + return mapSleef(Sleef_log10f4_u10, Sleef_log10d2_u10); + } + Vectorized log1p() const { + return mapSleef(Sleef_log1pf4_u10, Sleef_log1pd2_u10); + } + + Vectorized sin() const { + return mapSleef(Sleef_sinf4_u10, Sleef_sind2_u10); + } + Vectorized sinh() const { + return mapSleef(Sleef_sinhf4_u10, Sleef_sinhd2_u10); + } + Vectorized cos() const { + return mapSleef(Sleef_cosf4_u10, Sleef_cosd2_u10); + } + Vectorized cosh() const { + return mapSleef(Sleef_coshf4_u10, Sleef_coshd2_u10); + } + + Vectorized tan() const { + return mapSleef(Sleef_tanf4_u10, Sleef_tand2_u10); + } + Vectorized tanh() const { + return mapSleef(Sleef_tanhf4_u10, Sleef_tanhd2_u10); + } + + Vectorized lgamma() const { + return mapSleef(Sleef_lgammaf4_u10, Sleef_lgammad2_u10); + } + + Vectorized atan2(const Vectorized& b) const { + return mapSleef(Sleef_atan2f4_u10, Sleef_atan2d2_u10, b); + } + Vectorized copysign(const Vectorized& sign) const { + return mapSleef(Sleef_copysignf4, Sleef_copysignd2, sign); + } + Vectorized fmod(const Vectorized& q) const { + return mapSleef(Sleef_fmodf4, Sleef_fmodd2, q); + } + + Vectorized hypot(const Vectorized& b) const { + return mapSleef(Sleef_hypotf4_u05, Sleef_hypotd2_u05, b); + } + + Vectorized pow(const Vectorized& b) const { + return mapSleef(Sleef_powf4_u10, Sleef_powd2_u10, b); + } + + Vectorized nextafter(const Vectorized& b) const { + return mapSleef(Sleef_nextafterf4, Sleef_nextafterd2, b); + } + + Vectorized erfinv() const { + return mapOrdinary(calc_erfinv); + } + + Vectorized digamma() const { + return mapOrdinary(calc_digamma); + } + + Vectorized igamma(const Vectorized& x) const { + return mapOrdinary(calc_igamma, x); + } + + Vectorized igammac(const Vectorized& x) const { + return mapOrdinary(calc_igammac, x); + } + + Vectorized i0() const { + return mapOrdinary(calc_i0); + } + + Vectorized i0e() const { + return mapOrdinary(calc_i0e); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized minimum(const Vectorized& other) const { + return {vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)}; + } + + /* Propagates NaN if either input is a NaN. */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized minimum(const Vectorized& other) const { + Vectorized tmp = { + vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)}; + tmp = blendv(tmp, *this, isnan()); + return blendv(tmp, other, other.isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized maximum(const Vectorized& other) const { + return {vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)}; + } + + /* Propagates NaN if either input is a NaN. */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized maximum(const Vectorized& other) const { + Vectorized tmp = { + vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)}; + tmp = blendv(tmp, *this, isnan()); + return blendv(tmp, other, other.isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_min(const Vectorized& min) const { + return {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)}; + } + + /* Keeps NaN if actual value is NaN */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_min(const Vectorized& min) const { + Vectorized tmp = {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)}; + return blendv(tmp, *this, isnan()); + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_max(const Vectorized& max) const { + return {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)}; + } + + /* Keeps NaN if actual value is NaN */ + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized clamp_max(const Vectorized& max) const { + Vectorized tmp = {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)}; + return blendv(tmp, *this, isnan()); + } + + template , int> = 0> + Vectorized swapped() const { + auto swap_mask = GetSwapMaskFloat(); + vtype v0 = vec_perm(_vec0, _vec0, swap_mask); + vtype v1 = vec_perm(_vec1, _vec1, swap_mask); + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized swapped() const { + vtype v0 = {_vec0[1], _vec0[0]}; + vtype v1 = {_vec1[1], _vec1[0]}; + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + static Vectorized mergee(Vectorized& first, Vectorized& second) { + return { + vec_mergee(first._vec0, second._vec0), + vec_mergee(first._vec1, second._vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + static Vectorized mergeo(Vectorized& first, Vectorized& second) { + return { + vec_mergeo(first._vec0, second._vec0), + vec_mergeo(first._vec1, second._vec1)}; + } + + static Vectorized horizontal_add_perm( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.swapped(); // 2perm + auto second_perm = second.swapped(); // 2perm + // summ + auto first_ret = first + first_perm; // 2add + auto second_ret = second + second_perm; // 2 add + // now lets choose evens + return mergee(first_ret, second_ret); // 2 mergee's + } + + static Vectorized horizontal_sub_perm( + Vectorized& first, + Vectorized& second) { + // we will simulate it differently with 6 instructions total + // lets permute second so that we can add it getting horizontal sums + auto first_perm = first.swapped(); // 2perm + auto second_perm = second.swapped(); // 2perm + // summ + auto first_ret = first - first_perm; // 2sub + auto second_ret = second - second_perm; // 2 sub + // now lets choose evens + return mergee(first_ret, second_ret); // 2 mergee's + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized mergee() const { + return {vec_mergee(_vec0, _vec0), vec_mergee(_vec1, _vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized mergeo() const { + return {vec_mergeo(_vec0, _vec0), vec_mergeo(_vec1, _vec1)}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized to_vec_float_helper() const { + int32_t values[8] = { + _vec0[0], + _vec0[1], + _vec0[2], + _vec0[3], + _vec0[4], + _vec0[5], + _vec0[6], + _vec0[7], + }; + + return Vectorized{ + values[0], + values[1], + values[2], + values[3], + values[4], + values[5], + values[6], + values[7]}; + } + + template < + typename U = T, + std::enable_if_t, int> = 0> + Vectorized to_vec_uint8_helper() const { + // helper function for float to uint8_t conversion + uint8_t values[8] = { + static_cast(_vec0[0]), + static_cast(_vec0[1]), + static_cast(_vec0[2]), + static_cast(_vec0[3]), + static_cast(_vec1[0]), + static_cast(_vec1[1]), + static_cast(_vec1[2]), + static_cast(_vec1[3]), + }; + + return Vectorized{ + values[0], values[1], values[2], values[3], values[4], values[5], + values[6], values[7], 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, + }; + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() + b.vec0(), a.vec1() + b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() - b.vec0(), a.vec1() - b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() * b.vec0(), a.vec1() * b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator/( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() & b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() & b.vecb1())}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() | b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() | b.vecb1())}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + (Vectorized::vtype)(a.vecb0() ^ b.vecb0()), \ + (Vectorized::vtype)(a.vecb1() ^ b.vecb1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())} \ + ._not(); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpgt(a.vec0(), b.vec0()), vec_cmpgt(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmpge(a.vec0(), b.vec0()), vec_cmpge(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmplt(a.vec0(), b.vec0()), vec_cmplt(a.vec1(), b.vec1())}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{ \ + vec_cmple(a.vec0(), b.vec0()), vec_cmple(a.vec1(), b.vec1())}; \ + } + +ZVECTOR_OPERATORS(float) +ZVECTOR_OPERATORS(double) +ZVECTOR_OPERATORS(int8_t) +ZVECTOR_OPERATORS(uint8_t) +ZVECTOR_OPERATORS(uint16_t) +ZVECTOR_OPERATORS(int16_t) +ZVECTOR_OPERATORS(int32_t) +ZVECTOR_OPERATORS(int64_t) + +#undef ZVECTOR_OPERATORS + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator<<( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr Vectorized::ElementType max_shift = \ + sizeof(Vectorized::ElementType) * CHAR_BIT; \ + \ + Vectorized::ElementType a_array[Vectorized::size()]; \ + Vectorized::ElementType b_array[Vectorized::size()]; \ + Vectorized::ElementType c_array[Vectorized::size()]; \ + \ + a.store(a_array); \ + b.store(b_array); \ + \ + for (int i = 0; i != Vectorized::size(); i++) { \ + typex shift = b_array[i]; \ + if ((static_cast>(shift) < 0) || \ + (shift >= max_shift)) { \ + c_array[i] = 0; \ + } else { \ + c_array[i] = static_cast>(a_array[i]) \ + << shift; \ + } \ + } \ + \ + return Vectorized::loadu(c_array); \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator>>( \ + const Vectorized& a, const Vectorized& b) { \ + /* right shift value to retain sign bit for signed and no bits for \ + * unsigned */ \ + constexpr Vectorized::ElementType max_shift = \ + sizeof(typex) * CHAR_BIT - std::is_signed_v; \ + \ + Vectorized::ElementType a_array[Vectorized::size()]; \ + Vectorized::ElementType b_array[Vectorized::size()]; \ + Vectorized::ElementType c_array[Vectorized::size()]; \ + \ + a.store(a_array); \ + b.store(b_array); \ + \ + for (int i = 0; i != Vectorized::size(); i++) { \ + typex shift = b_array[i]; \ + if ((static_cast>(shift) < 0) || \ + (shift >= max_shift)) { \ + c_array[i] = a_array[i] >> max_shift; \ + } else { \ + c_array[i] = a_array[i] >> shift; \ + } \ + } \ + \ + return Vectorized::loadu(c_array); \ + } \ + \ + template <> \ + inline Vectorized operator~(const Vectorized& a) { \ + return a._not(); \ + } + +ZVECTOR_OPERATORS(int8_t) +ZVECTOR_OPERATORS(uint8_t) +ZVECTOR_OPERATORS(uint16_t) +ZVECTOR_OPERATORS(int16_t) +ZVECTOR_OPERATORS(int32_t) +ZVECTOR_OPERATORS(int64_t) + +#undef ZVECTOR_OPERATORS + +#define DEFINE_MAXMIN_FUNCS(operand_type) \ + template <> \ + Vectorized inline maximum( \ + const Vectorized& a, const Vectorized& b) { \ + return a.maximum(b); \ + } \ + template <> \ + Vectorized inline minimum( \ + const Vectorized& a, const Vectorized& b) { \ + return a.minimum(b); \ + } + +#define DEFINE_CLAMP_MAXMIN_FUNCS(typex) \ + DEFINE_MAXMIN_FUNCS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_min( \ + const Vectorized& a, const Vectorized& min) { \ + return a.clamp_min(min); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp_max( \ + const Vectorized& a, const Vectorized& max) { \ + return a.clamp_max(max); \ + } \ + template <> \ + Vectorized C10_ALWAYS_INLINE clamp( \ + const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return clamp_max(clamp_min(a, min), max); \ + } + +DEFINE_CLAMP_MAXMIN_FUNCS(int8_t) +DEFINE_CLAMP_MAXMIN_FUNCS(uint8_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int16_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int32_t) +DEFINE_CLAMP_MAXMIN_FUNCS(int64_t) +DEFINE_CLAMP_MAXMIN_FUNCS(float) +DEFINE_CLAMP_MAXMIN_FUNCS(double) + +namespace { /* unnamed namespace */ + +#if !defined(vec_float) || __ARCH__ < 13 +#warning \ + "float->int and int->float conversion is simulated. compile for z15 for improved performance" +inline ZSimdVect vec_int_flt(const ZSimdVect x) { + return ZSimdVect{float(x[0]), float(x[1]), float(x[2]), float(x[3])}; +} +inline ZSimdVect vec_flt_int(const ZSimdVect x) { + return ZSimdVect{int(x[0]), int(x[1]), int(x[2]), int(x[3])}; +} +#else +#define vec_int_flt vec_float +#define vec_flt_int vec_signed +#endif + +Vectorized zvec_convert_to_float(const Vectorized& x) { + return {vec_int_flt(x.vec0()), vec_int_flt(x.vec1())}; +} + +Vectorized zvec_convert_to_int(const Vectorized& x) { + return {vec_flt_int(x.vec0()), vec_flt_int(x.vec1())}; +} + +Vectorized zvec_convert_to_float(const Vectorized& x) { + return {vec_double(x.vec0()), vec_double(x.vec1())}; +} + +Vectorized zvec_convert_to_int(const Vectorized& x) { + return {vec_signed(x.vec0()), vec_signed(x.vec1())}; +} + +} /* unnamed namespace */ + +template +Vectorized cast_zvector(const Vectorized& x) { + using cast_type = typename Vectorized::vtype; + return Vectorized{(cast_type)x.vec0(), (cast_type)x.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + __builtin_s390_vfmasb(a.vec0(), b.vec0(), c.vec0()), + __builtin_s390_vfmasb(a.vec1(), b.vec1(), c.vec1())}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + __builtin_s390_vfmadb(a.vec0(), b.vec0(), c.vec0()), + __builtin_s390_vfmadb(a.vec1(), b.vec1(), c.vec1())}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} +template <> +Vectorized C10_ALWAYS_INLINE fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return Vectorized{ + a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()}; +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return zvec_convert_to_int(src); +} + +template <> +Vectorized C10_ALWAYS_INLINE +convert_to_int_of_same_size(const Vectorized& src) { + return zvec_convert_to_int(src); +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + // int32_t and float have same size + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int32_t* src_a = src + i; + float* dst_a = dst + i; + auto input_vec = Vectorized::loadu(src_a); + auto output_vec = zvec_convert_to_float(input_vec); + output_vec.store(dst_a); + } + + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int64_t* src, double* dst, int64_t n) { + int64_t i; + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + const int64_t* src_a = src + i; + double* dst_a = dst + i; + auto input_vec = Vectorized::loadu(src_a); + auto output_vec = zvec_convert_to_float(input_vec); + output_vec.store(dst_a); + } + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +#define DEFINE_REINTERPRET_CAST_FUNCS(Fst, Cst) \ + template <> \ + C10_ALWAYS_INLINE Vectorized cast( \ + const Vectorized& src) { \ + return cast_zvector(src); \ + } + +#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(Fst) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, double) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, float) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int64_t) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int32_t) \ + DEFINE_REINTERPRET_CAST_FUNCS(Fst, int16_t) + +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t) +DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t) + +#undef DEFINE_REINTERPRET_CAST_FUNCS + +template +struct unpack_type { + using type = T; +}; +template <> +struct unpack_type { + using type = int16_t; +}; +template <> +struct unpack_type { + using type = int16_t; +}; +template <> +struct unpack_type { + using type = int32_t; +}; + +template +struct pack_type { + using type = T; +}; +template <> +struct pack_type { + using type = int8_t; +}; +template <> +struct pack_type { + using type = int16_t; +}; + +namespace { /* unnamed namespace */ + +template ::type> +std::pair, Vectorized> unpack(const Vectorized& x) { + auto vec0 = vec_unpackh(x.vec0()); + auto vec1 = vec_unpackl(x.vec0()); + auto vec2 = vec_unpackh(x.vec1()); + auto vec3 = vec_unpackl(x.vec1()); + return {Vectorized{vec0, vec1}, Vectorized{vec2, vec3}}; +} + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") +template <> +std::pair, Vectorized> unpack( + const Vectorized& x) { + using typeX = typename Vectorized::vtype; + typeX vec0 = vec_unpackh(x.vec0()); + typeX vec1 = vec_unpackl(x.vec0()); + typeX vec2 = vec_unpackh(x.vec1()); + typeX vec3 = vec_unpackl(x.vec1()); + // auto mask = Vectorized(0xFF); + // vec0 = vec0 & mask; + // vec1 = vec1 & mask; + // vec2 = vec2 & mask; + // vec3 = vec3 & mask; + return { + cast_zvector(Vectorized{vec0, vec1}), + cast_zvector(Vectorized{vec2, vec3})}; +} +C10_DIAGNOSTIC_POP() + +template ::type> +Vectorized pack(const Vectorized& first, const Vectorized& second) { + auto vec0 = vec_packs(first.vec0(), first.vec1()); + auto vec1 = vec_packs(second.vec0(), second.vec1()); + return Vectorized{vec0, vec1}; +} + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") +template <> +Vectorized pack( + const Vectorized& first, + const Vectorized& second) { + auto vec0 = vec_packsu(first.vec0(), first.vec1()); + auto vec1 = vec_packsu(second.vec0(), second.vec1()); + return Vectorized{vec0, vec1}; +} +C10_DIAGNOSTIC_POP() + +} /* unnamed namespace */ + +//////////////////////////////////QUANT/////////////////////////////////////////// +template +struct is_vec_specialized_for< + T, + std::enable_if_t()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using value_type = typename T::underlying; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using vinner_type = Vectorized; + using size_type = int; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(value_type); + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + static constexpr int int_num_vecs() { + return float_num_vecs(); + } + using float_vec_return_type = std::array, float_num_vecs()>; + using int_vec_return_type = + std::array, int_num_vecs()>; + + private: + vinner_type _vec; + + public: + Vectorized() {} + + explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {} + Vectorized(const T& val) : _vec(val.val_) {} + + C10_ALWAYS_INLINE const vinner_type& vec() const { + return _vec; + } + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return Vectorized{vinner_type::loadu(ptr, count)}; + } + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + _vec.store(ptr, count); + } + + Vectorized relu(Vectorized zero_point) const { + return Vectorized{_vec.maximum(zero_point._vec)}; + } + + Vectorized relu6(Vectorized zero_point, Vectorized q_six) const { + auto ret_max = _vec.maximum(zero_point._vec); + auto ret_min = ret_max.minimum(q_six._vec); + return Vectorized{ret_min}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + int_vec_return_type widening_subtract(Vectorized b) const { + return {*this - b}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + auto float_val = zvec_convert_to_float(_vec); + return {fmadd(scale, float_val, scale_zp_premul)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + auto float_val = zvec_convert_to_float(_vec); + return {(float_val - zero_point) * scale}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 1, int> = 0> + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized vecf = rhs[0]; + vecf = vecf * Vectorized(inverse_scale); + vecf = vecf.rint() + Vectorized((float)(zero_point)); + auto veci = zvec_convert_to_int(vecf); + + return Vectorized{veci}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 1, int> = 0> + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized vi = inp[0]; + auto vecf = zvec_convert_to_float(vi.vec()); + vecf = vecf * Vectorized(multiplier); + vecf = vecf.rint(); + auto veci = zvec_convert_to_int(vecf) + Vectorized(zero_point); + + return Vectorized{veci}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 4, int> = 0> + int_vec_return_type widening_subtract(Vectorized b) const { + auto ret16 = unpack(_vec); + auto ret16B = unpack(b.vec()); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + auto ret32B_0 = unpack(ret16B.first); + auto ret32B_1 = unpack(ret16B.second); + + return { + Vectorized(ret32_0.first - ret32B_0.first), + Vectorized(ret32_0.second - ret32B_0.second), + Vectorized(ret32_1.first - ret32B_1.first), + Vectorized(ret32_1.second - ret32B_1.second)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + float_vec_return_type C10_ALWAYS_INLINE dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + // unpacking unsigned as signed + auto ret16 = unpack(_vec); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + + auto vecf_0 = zvec_convert_to_float(ret32_0.first); + auto vecf_1 = zvec_convert_to_float(ret32_0.second); + + auto vecf_2 = zvec_convert_to_float(ret32_1.first); + auto vecf_3 = zvec_convert_to_float(ret32_1.second); + return { + fmadd(scale, vecf_0, scale_zp_premul), + fmadd(scale, vecf_1, scale_zp_premul), + fmadd(scale, vecf_2, scale_zp_premul), + fmadd(scale, vecf_3, scale_zp_premul)}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + // unpacking unsigned as signed + auto ret16 = unpack(_vec); + auto ret32_0 = unpack(ret16.first); + auto ret32_1 = unpack(ret16.second); + + auto vecf_0 = zvec_convert_to_float(ret32_0.first); + auto vecf_1 = zvec_convert_to_float(ret32_0.second); + + auto vecf_2 = zvec_convert_to_float(ret32_1.first); + auto vecf_3 = zvec_convert_to_float(ret32_1.second); + + return { + (vecf_0 - zero_point) * scale, + (vecf_1 - zero_point) * scale, + (vecf_2 - zero_point) * scale, + (vecf_3 - zero_point) * scale}; + } + + template < + typename U = T, + std::enable_if_t::float_num_vecs() == 4, int> = 0> + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto vec_inverse = Vectorized(inverse_scale); + auto vec_zero_point = Vectorized((float)zero_point); + + auto vecf0 = rhs[0]; + auto vecf2 = rhs[1]; + auto vecf4 = rhs[2]; + auto vecf6 = rhs[3]; + + vecf0 = vecf0 * vec_inverse; + vecf2 = vecf2 * vec_inverse; + vecf4 = vecf4 * vec_inverse; + vecf6 = vecf6 * vec_inverse; + + vecf0 = vecf0.rint() + vec_zero_point; + vecf2 = vecf2.rint() + vec_zero_point; + vecf4 = vecf4.rint() + vec_zero_point; + vecf6 = vecf6.rint() + vec_zero_point; + + auto veci0 = zvec_convert_to_int(vecf0); + auto veci2 = zvec_convert_to_int(vecf2); + auto veci4 = zvec_convert_to_int(vecf4); + auto veci6 = zvec_convert_to_int(vecf6); + + auto vecshi0 = pack(veci0, veci2); + auto vecshi2 = pack(veci4, veci6); + auto ret = pack(vecshi0, vecshi2); + + return Vectorized{ret}; + } + + template < + typename U = T, + std::enable_if_t::int_num_vecs() == 4, int> = 0> + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized vec_multiplier = Vectorized(multiplier); + Vectorized vec_zero_point = Vectorized(zero_point); + + Vectorized vi0 = inp[0]; + Vectorized vi1 = inp[1]; + Vectorized vi2 = inp[2]; + Vectorized vi3 = inp[3]; + + auto vecf0 = zvec_convert_to_float(vi0.vec()); + auto vecf2 = zvec_convert_to_float(vi1.vec()); + + auto vecf4 = zvec_convert_to_float(vi2.vec()); + auto vecf6 = zvec_convert_to_float(vi3.vec()); + + vecf0 = vecf0 * vec_multiplier; + vecf2 = vecf2 * vec_multiplier; + + vecf4 = vecf4 * vec_multiplier; + vecf6 = vecf6 * vec_multiplier; + + vecf0 = vecf0.rint(); + vecf2 = vecf2.rint(); + vecf4 = vecf4.rint(); + vecf6 = vecf6.rint(); + + auto veci0 = zvec_convert_to_int(vecf0); + auto veci2 = zvec_convert_to_int(vecf2); + auto veci4 = zvec_convert_to_int(vecf4); + auto veci6 = zvec_convert_to_int(vecf6); + + veci0 = veci0 + vec_zero_point; + veci2 = veci2 + vec_zero_point; + + veci4 = veci4 + vec_zero_point; + veci6 = veci6 + vec_zero_point; + + auto vecshi0 = pack(veci0, veci2); + auto vecshi2 = pack(veci4, veci6); + + auto ret = pack(vecshi0, vecshi2); + + return Vectorized{ret}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + return Vectorized{_vec.eq(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + return Vectorized{_vec.ne(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const { + return Vectorized{_vec.gt(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const { + return Vectorized{_vec.ge(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const { + return Vectorized{_vec.lt(other._vec)}; + } + Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const { + return Vectorized{_vec.le(other._vec)}; + } + + Vectorized clamp_min(const Vectorized& min) const { + return Vectorized{_vec.clamp_min(min._vec)}; + } + + Vectorized clamp_max(const Vectorized& max) const { + return Vectorized{_vec.clamp_max(max._vec)}; + } + + Vectorized minimum(const Vectorized& other) const { + return Vectorized{_vec.minimum(other._vec)}; + } + + Vectorized maximum(const Vectorized& other) const { + return Vectorized{_vec.maximum(other._vec)}; + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() + b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() - b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator*( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() * b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator/( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() / b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() & b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() | b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() ^ b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() == b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() != b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() > b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() >= b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() < b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() <= b.vec()}; \ + } + +ZVECTOR_OPERATORS(c10::qint32) +ZVECTOR_OPERATORS(c10::qint8) +ZVECTOR_OPERATORS(c10::quint8) + +#undef ZVECTOR_OPERATORS + +DEFINE_CLAMP_MAXMIN_FUNCS(c10::quint8) +DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint8) +DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint32) + +template +constexpr auto real_mask() { + return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFF, 0, 0xFFFFFFFF, 0}; +} + +template <> +constexpr auto real_mask() { + return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFFFFFFFFFF, 0}; +} + +template +constexpr auto image_mask() { + return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFF, 0, 0xFFFFFFFF}; +} + +template <> +constexpr auto image_mask() { + return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFFFFFFFFFF}; +} + +template +constexpr auto rsign_mask() { + return ZSimdVect{-0.f, 0.f, -0.f, 0.f}; +} + +template <> +constexpr auto rsign_mask() { + return ZSimdVect{-0.0, 0.f}; +} + +template +constexpr auto isign_mask() { + return ZSimdVect{0.0, -0.f, 0.0, -0.f}; +} + +template <> +constexpr auto isign_mask() { + return ZSimdVect{0.0, -0.0}; +} + +template +constexpr auto image_one() { + return ZSimdVect{0, 1.f, 0, 1.f}; +} + +template <> +constexpr auto image_one() { + return ZSimdVect{0.0, 1.0}; +} + +template +constexpr auto pi_half() { + return ZSimdVect{(float)(M_PI / 2.0), 0.f, (float)(M_PI / 2.0), 0.f}; +} + +template <> +constexpr auto pi_half() { + return ZSimdVect{M_PI / 2.0, 0.0}; +} + +template +constexpr auto image_half() { + return ZSimdVect{0, 0.5f, 0, 0.5f}; +} + +template <> +constexpr auto image_half() { + return ZSimdVect{0.0, 0.5}; +} + +template +constexpr U log2e_inv() { + return static_cast(1.4426950408889634); +} + +template +constexpr U log10e_inv() { + return static_cast(0.43429448190325176); +} + +template +struct is_vec_specialized_for< + T, + std::enable_if_t()>> + : std::bool_constant {}; + +template +struct Vectorized()>> { + public: + using underline_type = decltype(std::declval().imag()); + using value_type = T; + using vtype = ZSimdVect; + using vmaskType = ZSimdVectBinary; + using vinner_type = Vectorized; + using size_type = int; + using vinner_data = typename Vectorized::vinner_data; + + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(value_type); + } + + private: + vinner_type _vec; + + public: + Vectorized() {} + + C10_ALWAYS_INLINE Vectorized(const vinner_data& v) + : _vec{v.first, v.second} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2) + : _vec{s1.real(), s1.imag(), s2.real(), s2.imag()} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4) + : _vec{ + s1.real(), + s1.imag(), + s2.real(), + s2.imag(), + s3.real(), + s3.imag(), + s4.real(), + s4.imag()} {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s) {} + + template = 0> + C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s, s, s) {} + + C10_ALWAYS_INLINE operator vinner_type() const { + return _vec; + } + + C10_ALWAYS_INLINE const vinner_type& vec() const { + return _vec; + } + + C10_ALWAYS_INLINE operator vinner_data() const { + return _vec.data(); + } + + C10_ALWAYS_INLINE vinner_data data() const { + return _vec.data(); + } + + template + static Vectorized C10_ALWAYS_INLINE + loadu(const U* ptr, int count = size()) { + return Vectorized{vinner_type::loadu(ptr, 2 * count)}; + } + + template + void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const { + return _vec.store(ptr, 2 * count); + } + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + // convert std::complex index mask to V index mask: xy -> xxyy + vinner_type vmask = mask.vec(); + auto mask_complex = vinner_type( + vec_mergeh(vmask.vec0(), vmask.vec0()), + vec_mergeh(vmask.vec1(), vmask.vec1())); + return Vectorized{vinner_type::blendv(a.vec(), b.vec(), mask_complex)}; + } + + template + static auto C10_ALWAYS_INLINE + blend(const Vectorized& a, const Vectorized& b) { + constexpr int mask_complex = maskForComplex(mask); + return Vectorized{ + vinner_type::template blend(a.vec(), b.vec())}; + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized(base, base + step); + } + + template + static std::enable_if_t> arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + value_type(2) * step, + base + value_type(3) * step); + } + + template + static inline std::enable_if_t<(Z >= C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + return b; + } + + template + static inline std::enable_if_t<(Z < C), Vectorized> set_inner( + const Vectorized& a, + const Vectorized& b, + size_t count) { + if (count == Z) + return blend(a, b); + else + return set_inner(a, b, count); + } + + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + size_t count = size()) { + if (count == 0) + return a; + return set_inner<1, size()>(a, b, count); + } + + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(const T&)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{ + f(T(v0[0], v0[1])), + f(T(v0[2], v0[3])), + f(T(v1[0], v1[1])), + f(T(v1[2], v1[3]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(const T&)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(T)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{ + f(T(v0[0], v0[1])), + f(T(v0[2], v0[3])), + f(T(v1[0], v1[1])), + f(T(v1[2], v1[3]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + Vectorized mapOrdinary(T (*const f)(T)) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + inline Vectorized mapOrdinary( + T (*const f)(const T&, const T&), + const Vectorized& b) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + auto bvec = b.vec(); + auto b0 = bvec.vec0(); + auto b1 = bvec.vec1(); + T a00 = f(T(v0[0], v0[1]), T(b0[0], b0[1])); + T a01 = f(T(v0[2], v0[3]), T(b0[2], b0[3])); + T a02 = f(T(v1[0], v1[1]), T(b1[0], b1[1])); + T a03 = f(T(v1[2], v1[3]), T(b1[2], b1[3])); + return Vectorized{a00, a01, a02, a03}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + inline Vectorized mapOrdinary( + T (*const f)(const T&, const T&), + const Vectorized& b) const { + auto v0 = _vec.vec0(); + auto v1 = _vec.vec1(); + auto bvec = b.vec(); + auto b0 = bvec.vec0(); + auto b1 = bvec.vec1(); + U a00 = f(U(v0[0], v0[1]), U(b0[0], b0[1])); + U a01 = f(U(v1[0], v1[1]), U(b1[0], b1[1])); + return Vectorized{a00, a01}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg( + const typename Vectorized::vinner_type& a) { + const auto swap_mask = ZSimdVectBinary{ + 0, 1, 2, 3, 20, 21, 22, 23, 8, 9, 10, 11, 28, 29, 30, 31}; + + auto a_neg = a.neg(); + vtype v0 = vec_perm(a_neg.vec0(), a.vec0(), swap_mask); + vtype v1 = vec_perm(a_neg.vec1(), a.vec1(), swap_mask); + return {v0, v1}; + } + + template < + typename U = T, + std::enable_if_t>::value, int> = 0> + static typename Vectorized::vinner_type real_neg( + const typename Vectorized::vinner_type& a) { + auto a_neg = a.neg(); + vtype v0 = {a_neg.vec0()[0], a.vec0()[1]}; + vtype v1 = {a_neg.vec1()[0], a.vec1()[1]}; + return {v0, v1}; + } + + Vectorized angle2_() const { + auto b_a = _vec.swapped(); // b a + return Vectorized{_vec.atan2(b_a).swapped()}; + } + + Vectorized angle() const { + return angle2_().real(); + } + + Vectorized atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + auto ione = Vectorized{vinner_type(image_one())}; + auto sum = ione + *this; + auto sub = ione - *this; + auto ln = (sum / sub).log(); // ln((i + z)/(i - z)) + return ln * + Vectorized{vinner_type(image_half())}; // i/2*ln() + } + + Vectorized atanh() const { + return mapOrdinary(std::atanh); + } + + Vectorized asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) +#if 1 + vinner_type cnj = conj().vec(); + vinner_type b_a = cnj.swapped(); + vinner_type ab = cnj * b_a; + vinner_type im = ab + ab; + vinner_type val_2 = _vec * _vec; + vinner_type val_2_swapped = val_2.swapped(); + vinner_type re = vinner_type::horizontal_sub_perm(val_2, val_2_swapped); + re = vinner_type(static_cast(1)) - re; + constexpr int blend_mask = + blend_choice(); // 0x0A for complex , 0xAA for complex + vinner_type blendx = vinner_type::template blend(re, im); + auto root = Vectorized(blendx).sqrt(); + auto ln = Vectorized(Vectorized(b_a) + root).log(); + return Vectorized(ln.vec().swapped()).conj(); +#else + return mapOrdinary(std::asin); +#endif + } + + Vectorized acos() const { + // acos(x) = pi/2 - asin(x) + return Vectorized(vinner_type(pi_half())) - asin(); + } + + Vectorized sin() const { + return mapOrdinary(std::sin); + } + Vectorized sinh() const { + return mapOrdinary(std::sinh); + } + Vectorized cos() const { + return mapOrdinary(std::cos); + } + Vectorized cosh() const { + return mapOrdinary(std::cosh); + } + Vectorized ceil() const { + return Vectorized{_vec.ceil()}; + } + Vectorized floor() const { + return Vectorized{_vec.floor()}; + } + Vectorized neg() const { + return Vectorized(_vec.neg()); + } + Vectorized round() const { + return Vectorized{_vec.round()}; + } + Vectorized tan() const { + return mapOrdinary(std::tan); + } + Vectorized tanh() const { + return mapOrdinary(std::tanh); + } + Vectorized trunc() const { + return Vectorized{_vec.trunc()}; + } + + Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const { + auto eq = _vec.eq(other._vec); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + auto real = eq & vinner_type(real_mask()); + auto imag = (eq & vinner_type(image_mask())).swapped(); + return Vectorized{real & imag}; + } + Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const { + auto ne = _vec.ne(other._vec); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + auto real = ne & vinner_type(real_mask()); + auto imag = (ne & vinner_type(image_mask())).swapped(); + return Vectorized{real | imag}; + } + + Vectorized real() const { + return Vectorized(_vec & vinner_type(real_mask())); + } + Vectorized imag_() const { + return Vectorized(_vec & vinner_type(image_mask())); + } + Vectorized imag() const { + return Vectorized{ + (_vec & vinner_type(image_mask())).swapped()}; + } + + Vectorized conj() const { + return Vectorized(_vec ^ vinner_type(isign_mask())); + } + + vinner_data abs_2_() const { + auto a = _vec * _vec; + a = a + a.swapped(); + return a.mergee().data(); + } + + static T abs_helper(const T& value) { + return T(std::abs(value)); + } + + Vectorized abs() const { + return mapOrdinary(abs_helper); + } + + Vectorized exp() const { + return mapOrdinary(std::exp); + } + + Vectorized exp2() const { + return mapOrdinary(exp2_impl); + } + + Vectorized expm1() const { + return mapOrdinary(std::expm1); + } + + Vectorized log() const { + return mapOrdinary(std::log); + } + + Vectorized log2() const { + // log2eB_inv + auto ret = log(); + return Vectorized{ret._vec * vinner_type(log2e_inv())}; + } + + Vectorized log10() const { + auto ret = log(); + return Vectorized{ret._vec * vinner_type(log10e_inv())}; + } + + Vectorized log1p() const { + return mapOrdinary(std::log1p); + } + + Vectorized sgn() const { + return mapOrdinary(at::native::sgn_impl); + } + + Vectorized pow(const Vectorized& exp) const { + return mapOrdinary(std::pow, exp); + } + + Vectorized sqrt() const { + return mapOrdinary(std::sqrt); + } + + Vectorized reciprocal() const { + // re + im*i = (a + bi) / (c + di) + // re = (ac + bd)/abs_2() = c/abs_2() + // im = (bc - ad)/abs_2() = d/abs_2() + vinner_type c_d = _vec ^ vinner_type(isign_mask()); + vinner_type abs = abs_2_(); + return Vectorized{c_d / abs}; + } + + Vectorized rsqrt() const { + return sqrt().reciprocal(); + } + + Vectorized lt(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized le(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized gt(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized ge(const Vectorized& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } +}; + +#define ZVECTOR_OPERATORS(typex) \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator+( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() + b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator-( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() - b.vec()}; \ + } \ + \ + template <> \ + Vectorized inline operator*( \ + const Vectorized& a, const Vectorized& b) { \ + /* (a + bi) * (c + di) = (ac - bd) + (ad + bc)i */ \ + Vectorized::vinner_type bv = b.vec(); \ + \ + /* this is more z arch friendly than simulating horizontal from x86 */ \ + Vectorized::vinner_type vi = bv.mergeo(); \ + Vectorized::vinner_type vr = bv.mergee(); \ + vi = vi ^ \ + Vectorized::vinner_type( \ + rsign_mask::underline_type>()); \ + Vectorized::vinner_type ret = a.vec() * vr; \ + Vectorized::vinner_type vx_swapped = a.vec().swapped(); \ + ret = fmadd(vx_swapped, vi, ret); \ + \ + return Vectorized{ret}; \ + } \ + \ + template <> \ + Vectorized inline operator/( \ + const Vectorized& a, const Vectorized& b) { \ + /* Unfortunately, this breaks some tests */ \ + /* Implement it like it's done for avx2 */ \ + auto fabs_cd = b.vec().abs(); /* |c| |d| */ \ + auto fabs_dc = fabs_cd.swapped(); /* |d| |c| */ \ + auto scale = Vectorized::vinner_type{1.0} / \ + maximum(fabs_cd, fabs_dc); /* 1/sc 1/sc */ \ + auto a2 = a.vec() * scale; /* a/sc b/sc */ \ + auto b2 = b.vec() * scale; /* c/sc d/sc */ \ + auto acbd2 = a2 * b2; /* ac/sc^2 bd/sc^2 */ \ + \ + auto dc2 = b2.swapped(); /* d/sc c/sc */ \ + dc2 = Vectorized::real_neg(dc2); /* -d/|c,d| c/sc */ \ + auto adbc2 = a2 * dc2; /* -ad/sc^2 bc/sc^2 */ \ + auto sum1 = acbd2 + acbd2.swapped(); /* (ac+bd)/sc^2 (ac+bd)/sc^2 */ \ + auto sum2 = adbc2 + adbc2.swapped(); /* (bc-ad)/sc^2 (bc-ad)/sc^2 */ \ + auto res2 = Vectorized::vinner_type::mergee( \ + sum1, sum2); /* (ac+bd)/sc^2 (bc-ad)/sc^2 */ \ + \ + /* get the denominator */ \ + Vectorized::vinner_type denom2 = \ + Vectorized{b2}.abs_2_(); /* (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 */ \ + res2 = res2 / denom2; \ + return Vectorized{res2}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator&( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() & b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator|( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() | b.vec()}; \ + } \ + \ + template <> \ + Vectorized C10_ALWAYS_INLINE operator^( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() ^ b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator==( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() == b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator!=( \ + const Vectorized& a, const Vectorized& b) { \ + return Vectorized{a.vec() != b.vec()}; \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator<=( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } \ + \ + Vectorized C10_ALWAYS_INLINE operator>=( \ + const Vectorized& a, const Vectorized& b) { \ + TORCH_CHECK(false, "not supported for complex numbers"); \ + } + +ZVECTOR_OPERATORS(c10::complex) +ZVECTOR_OPERATORS(c10::complex) + +#undef ZVECTOR_OPERATORS + +template = 0> +std::pair, Vectorized> inline inner_interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3} + // b = {b0, b1, b2, b3} + using vtype = typename Vectorized::vtype; + vtype ab00 = {a.vec0()[0], b.vec0()[0]}; + vtype ab11 = {a.vec0()[1], b.vec0()[1]}; + vtype ab2_00 = {a.vec1()[0], b.vec1()[0]}; + vtype ab2_11 = {a.vec1()[1], b.vec1()[1]}; + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair( + Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11}); +} + +template = 0> +std::pair, Vectorized> inline inner_deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + using vtype = typename Vectorized::vtype; + vtype aa01 = {a.vec0()[0], a.vec1()[0]}; + vtype aa23 = {b.vec0()[0], b.vec1()[0]}; + + vtype bb_01 = {a.vec0()[1], a.vec1()[1]}; + vtype bb_23 = {b.vec0()[1], b.vec1()[1]}; + + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair(Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23}); +} + +template = 0> +std::pair, Vectorized> inline inner_interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3,, a4, a5, a6, a7} + // b = {b0, b1, b2, b3,, b4, b5, b6, b7} + using vtype = typename Vectorized::vtype; + vtype ab0011 = vec_mergeh(a.vec0(), b.vec0()); + vtype ab2233 = vec_mergel(a.vec0(), b.vec0()); + + vtype ab2_0011 = vec_mergeh(a.vec1(), b.vec1()); + vtype ab2_2233 = vec_mergel(a.vec1(), b.vec1()); + // group cols crossing lanes: + // return {a0, b0, a1, b1,, a2, b2, a3, b3} + // {a4, b4, a5, b5,, a6, b6, a7, b7} + + return std::make_pair( + Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233}); +} + +template = 0> +std::pair, Vectorized> inline inner_deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1,, a2, b2, a3, b3} + // b = {a4, b4, a5, b5,, a6, b6, a7, b7} + using vtype = typename Vectorized::vtype; + // {a0,a2,b0,b2} {a1,a3,b1,b3} + vtype a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1()); + vtype a1a3b1b3 = vec_mergel(a.vec0(), a.vec1()); + + vtype aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3); + vtype bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3); + + vtype a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1()); + vtype a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1()); + + vtype aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2); + vtype bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2); + + // it could be done with vec_perm ,too + // swap lanes: + // return {a0, a1, a2, a3,, a4, a5, a6, a7} + // {b0, b1, b2, b3,, b4, b5, b6, b7} + + return std::make_pair( + Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2}); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_interleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2< + int32_t>(const Vectorized& a, const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template <> +std::pair, Vectorized> inline deinterleave2< + int64_t>(const Vectorized& a, const Vectorized& b) { + return inner_deinterleave2(a, b); +} + +template +std::enable_if_t< + std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(const Vectorized& src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 64 bits + auto vec_int = src.to_vec_float_helper(); + + return zvec_convert_to_float(vec_int); +} + +template +std::enable_if_t< + std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(const Vectorized& src) { + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + auto vec_int = clamp( + zvec_convert_to_int(src), + Vectorized(min_val), + Vectorized(max_val)); + + return vec_int.to_vec_uint8_helper(); +} + +#undef DEFINE_CLAMP_MAXMIN_FUNCS +#undef DEFINE_MAXMIN_FUNCS +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h new file mode 100644 index 0000000000000000000000000000000000000000..c0250e40e3a7ecb2dfdf5ce4da5e2f22289b1a83 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h @@ -0,0 +1,414 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +// clang-format off +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// clang-format on + +#include +#include +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; +} +inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; +} + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << ']'; + return stream; +} + +#if defined(CPU_CAPABILITY_AVX512) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castpd_ps(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castps_pd(src); +} + +template <> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castsi512_ps(src); +} + +template <> +inline Vectorized cast( + const Vectorized& src) { + return _mm512_castsi512_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + double>> inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm512_i64gather_pd(vindex, base_addr, scale); +} + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + float>> inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm512_i32gather_ps(vindex, base_addr, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +#ifndef _MSC_VER +// MSVC is not working well on complex function overload. +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const double* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF)); + auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + const float* base_addr, + const Vectorized& vindex, + Vectorized& mask) { + auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF)); + auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale); +} +#endif +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm512_cvtpd_epi64(src); +} + +template <> +Vectorized inline convert_to_int_of_same_size( + const Vectorized& src) { + return _mm512_cvttps_epi32(src); +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm512_cvtepi64_pd(src); +} + +template <> +Vectorized inline convert_to_fp_of_same_size( + const Vectorized& src) { + return _mm512_cvtepi32_ps(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4); + return std::make_pair( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> inline interleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, + // a15} b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, + // b14, b15} + // + // return: + // {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, + // b15} + __m512i idx1 = + _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + return std::make_pair( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + // The members of indices have been written in binary format for better + // understandability + __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> inline deinterleave2( + const Vectorized& a, + const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, + // a15, b15} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, + // a15} + // {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, + // b15} + __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = + _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_ps(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_pd(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_epi64(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = + _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_epi32(mask, v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + const __m512i mask = _mm512_set_epi16( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31); + return _mm512_permutexvar_epi16(mask, v); +} + +inline __m512i flip8(const __m512i& v) { + const __m512i mask1 = _mm512_set_epi8( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15); + const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6); + auto reversed_vec = _mm512_shuffle_epi8(v, mask1); + return _mm512_permutexvar_epi64(mask2, reversed_vec); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +template <> +inline Vectorized flip(const Vectorized& v) { + return flip8(v); +} + +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m512i* self_ = reinterpret_cast(self.as_bytes()); + const __m512i* other_ = reinterpret_cast(other.as_bytes()); + __m512i out = _mm512_and_si512(*self_, *other_); + Vectorized ret; + // We do not have a constructor that takes __m512i, so we need to memcpy + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + +#endif // defined(CPU_CAPABILITY_AVX512) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..44a632b3fb6ef40b766b95446efd36d3e4d72657 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -0,0 +1,1947 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +#ifndef SLEEF_CONST +#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER) +#define SLEEF_CONST const +#else +#define SLEEF_CONST +#endif +#define SLEEF_CONST_OLD SLEEF_CONST +#else +#define SLEEF_CONST_OLD +#endif + +// bfloat16 conversion +static inline void cvtbf16_fp32(const __m256i& a, __m512& o) { + o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m256i cvtfp32_bf16(const __m512& src) { + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones); + auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm512_add_epi32(t_lo, vec_bias); + t_hi = _mm512_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm512_add_epi32(t_lo, lo); + t_hi = _mm512_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm512_srli_epi32(t_lo, 16); + t_hi = _mm512_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo); + t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi); + + t_lo = _mm512_packus_epi32( + t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, t_lo); +} + +static inline __m512i merge_compare_result(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + lo = _mm512_srli_epi32(lo, 16); + hi = _mm512_srli_epi32(hi, 16); + auto out = _mm512_packus_epi32(lo, hi); + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, out); +} + +// float16 conversion +static inline void cvtfp16_fp32(const __m256i& a, __m512& o) { + o = _mm512_cvtph_ps(a); +} + +static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtfp16_fp32(lo, o1); + cvtfp16_fp32(hi, o2); +} + +static inline __m256i cvtfp32_fp16(const __m512& src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) { + __m256i lo = + _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i hi = + _mm512_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo)); + __m256 t_hi = _mm256_castsi256_ps(hi); + return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1)); +} + +// dtype conversion between float16/bfloat16 and float32 +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m256i& a, __m512& o); +template <> +inline void cvt_to_fp32(const __m256i& a, __m512& o) { + cvtbf16_fp32(a, o); +} +template <> +inline void cvt_to_fp32(const __m256i& a, __m512& o) { + cvtfp16_fp32(a, o); +} + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2); +template <> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) { + cvtbf16_fp32(a, o1, o2); +} +template <> +inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) { + cvtfp16_fp32(a, o1, o2); +} + +template < + typename T, + bool is_compare_op = false, + typename std::enable_if_t, int> = 0> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b); +template <> +inline __m512i cvt_from_fp32( + const __m512& a, + const __m512& b) { + return cvtfp32_bf16(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return merge_compare_result(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return cvtfp32_fp16(a, b); +} +template <> +inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) { + return cvtfp32_fp16(a, b); +} + +template +class Vectorized16 { + static_assert( + is_reduced_floating_point_v, + "Support only float16 and bfloat16."); + + private: + __m512i values; + + public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 32; + } + Vectorized16() { + values = _mm512_setzero_si512(); + } + Vectorized16(__m512i v) : values(v) {} + Vectorized16(T val) { + value_type uw = val.x; + values = _mm512_set1_epi16(uw); + } + Vectorized16( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32) { + values = _mm512_set_epi16( + val32.x, + val31.x, + val30.x, + val29.x, + val28.x, + val27.x, + val26.x, + val25.x, + val24.x, + val23.x, + val22.x, + val21.x, + val20.x, + val19.x, + val18.x, + val17.x, + val16.x, + val15.x, + val14.x, + val13.x, + val12.x, + val11.x, + val10.x, + val9.x, + val8.x, + val7.x, + val6.x, + val5.x, + val4.x, + val3.x, + val2.x, + val1.x); + } + operator __m512i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0)); + } + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) + return _mm512_loadu_si512(reinterpret_cast(ptr)); + + __mmask32 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi16(mask, ptr); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi16(ptr, mask, values); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange( + T base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + case 16: + return blend<65535>(a, b); + case 17: + return blend<131071>(a, b); + case 18: + return blend<262143>(a, b); + case 19: + return blend<524287>(a, b); + case 20: + return blend<1048575>(a, b); + case 21: + return blend<2097151>(a, b); + case 22: + return blend<4194303>(a, b); + case 23: + return blend<8388607>(a, b); + case 24: + return blend<16777215>(a, b); + case 25: + return blend<33554431>(a, b); + case 26: + return blend<67108863>(a, b); + case 27: + return blend<134217727>(a, b); + case 28: + return blend<268435455>(a, b); + case 29: + return blend<536870911>(a, b); + case 30: + return blend<1073741823>(a, b); + case 31: + return blend<2147483647>(a, b); + } + return b; + } +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wignored-qualifiers" + + Vectorized map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized isnan() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __mmask16 lo_mask, hi_mask; + __m512 zero = _mm512_set1_ps(0.0); + __m512i zeroi = _mm512_castps_si512(zero); + lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q); + lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF)); + hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q); + hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF)); + return merge_compare_result(lo, hi); + } +#pragma clang diagnostic pop + Vectorized abs() const { + return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values); + } + Vectorized angle() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto angle_lambda = [](__m512 values) { + const auto zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto non_nan_mask_vec = _mm512_mask_set1_epi32( + _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask( + _mm512_castsi512_ps(non_nan_mask_vec), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf16_u10); + } + Vectorized acosh() const { + return map(Sleef_acoshf16_u10); + } + Vectorized asin() const { + return map(Sleef_asinf16_u10); + } + Vectorized asinh() const { + return map(Sleef_asinhf16_u10); + } + Vectorized atan() const { + return map(Sleef_atanf16_u10); + } + Vectorized atanh() const { + return map(Sleef_atanhf16_u10); + } + Vectorized atan2(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f16_u10(lo, b1); + auto o2 = Sleef_atan2f16_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized copysign(const Vectorized& sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m512i mask_value = _mm512_set1_epi32(~0x80008000); + __m512i mask_signbit = _mm512_set1_epi32(0x80008000); + return Vectorized(_mm512_or_si512( + _mm512_and_si512(values, mask_value), + _mm512_and_si512(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff16_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf16_u15); + } + Vectorized erfinv() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf16_u10); + } + Vectorized exp2() const { + return map(Sleef_exp2f16_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f16_u10); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + __m512 x_lo, x_hi; + cvt_to_fp32(values, x_lo, x_hi); + __m512 q_lo, q_hi; + cvtbf16_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf16(x_lo, q_lo); + auto o2 = Sleef_fmodf16(x_hi, q_hi); + return cvt_from_fp32(o1, o2); + } + Vectorized hypot(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf16_u05(lo, b1); + auto o2 = Sleef_hypotf16_u05(hi, b2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized i0e() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized digamma() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_digamma(tmp1[i]); + tmp2[i] = calc_digamma(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized igamma(const Vectorized& x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + + Vectorized igammac(const Vectorized& x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvt_from_fp32(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf16_u10); + } + Vectorized log2() const { + return map(Sleef_log2f16_u10); + } + Vectorized log10() const { + return map(Sleef_log10f16_u10); + } + Vectorized log1p() const { + return map(Sleef_log1pf16_u10); + } + Vectorized sin() const { + return map(Sleef_sinf16_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf16_u10); + } + Vectorized cos() const { + return map(Sleef_cosf16_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf16_u10); + } + Vectorized ceil() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_ceil_ps(lo); + auto o2 = _mm512_ceil_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized floor() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_floor_ps(lo); + auto o2 = _mm512_floor_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized neg() const { + return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000)); + } + Vectorized round() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_roundscale_ps( + lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = _mm512_roundscale_ps( + hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf16_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf16_u10); + } + Vectorized trunc() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = + _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = + _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvt_from_fp32(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf16_u10); + } + Vectorized sqrt() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto o1 = _mm512_sqrt_ps(lo); + auto o2 = _mm512_sqrt_ps(hi); + return cvt_from_fp32(o1, o2); + } + Vectorized reciprocal() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, lo); + auto o2 = _mm512_div_ps(ones, hi); + return cvt_from_fp32(o1, o2); + } + Vectorized rsqrt() const { + __m512 lo, hi; + cvt_to_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo)); + auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi)); + return cvt_from_fp32(o1, o2); + } + Vectorized pow(const Vectorized& b) const { + __m512 lo, hi; + __m512 b1, b2; + cvt_to_fp32(values, lo, hi); + cvt_to_fp32(b.values, b1, b2); + auto o1 = Sleef_powf16_u10(lo, b1); + auto o2 = Sleef_powf16_u10(hi, b2); + return cvt_from_fp32(o1, o2); + } + + private: + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvt_to_fp32(values, a_lo, a_hi); + cvt_to_fp32(b.values, b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); + } + + public: + Vectorized inline operator>(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator<(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator>=(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator<=(const Vectorized& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator==(const Vectorized16& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + Vectorized inline operator!=(const Vectorized16& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } +}; + +template +static inline Vectorized binary_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvt_to_fp32(__m512i(a), a_lo, a_hi); + cvt_to_fp32(__m512i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvt_from_fp32(o1, o2); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = BFloat16; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, BFloat16* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = _mm512_loadu_ps(&src[i]); + __m512 b = _mm512_loadu_ps(&src[i + 16]); + + __m512i bf = cvtfp32_bf16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, BFloat16* dst, int64_t n) { + auto load_float = [](const double* src) -> __m512 { + // Load one float vector from an array of doubles + __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src)); + __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8)); + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = load_float(&src[i]); + __m512 b = load_float(&src[i + 16]); + + __m512i bf = cvtfp32_bf16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + cvtbf16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) { + __m512i r[8]; + // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15 e0e1 e2e3 e4e5 e6e7 e8e9 + // e10e11 e12e13 e14e15 b0-b15 f0-f15 c0-c15 g0-g15 d0-d15 h0-h15 i0-i15 + // m0-m15 j0-j15 n0-n15 k0-k15 o0-o15 l0-l15 p0-p15 +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; i++) { + r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01); + r[i + 4] = + _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01); + } + + // u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 + // f8f9 e10e11 f10f11 u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 + // f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15 u2: c0c1 d0d1 c2c3 d2d3 c8c9 + // d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11 u3: c4c5 + // d4b5 c6c7 d6b7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7 g12g13 + // h12h13 g14g15 h14h15 i j m n k l o p +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 8; i += 2) { + u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]); + u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]); + } + + // r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 + // g8g9 h8h9 r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 + // g2g3 h2h3 e10e11 f10f11 g10g11 h10h11 r2: a4a5 b4b5 c4c5 d4b5 a12a13 b12b13 + // c12c13 d12d13 r3: a6a7 b6b7 c6c7 d6b7 a14a15 b14b15 c14c15 d14d15 r4: i j k + // l m n o p + r[0] = _mm512_unpacklo_epi64(u[0], u[2]); + r[1] = _mm512_unpackhi_epi64(u[0], u[2]); + r[2] = _mm512_unpacklo_epi64(u[1], u[3]); + r[3] = _mm512_unpackhi_epi64(u[1], u[3]); + r[4] = _mm512_unpacklo_epi64(u[4], u[6]); + r[5] = _mm512_unpackhi_epi64(u[4], u[6]); + r[6] = _mm512_unpacklo_epi64(u[5], u[7]); + r[7] = _mm512_unpackhi_epi64(u[5], u[7]); + + __m512i const1 = _mm512_set_epi32( + 0x00370035, + 0x00330031, + 0x00270025, + 0x00230021, + 0x00170015, + 0x00130011, + 0x00070005, + 0x00030001, + 0x00360034, + 0x00320030, + 0x00260024, + 0x00220020, + 0x00160014, + 0x00120010, + 0x00060004, + 0x00020000); + __m512i const2 = _mm512_set_epi32( + 0x003f003d, + 0x003b0039, + 0x002f002d, + 0x002b0029, + 0x001f001d, + 0x001b0019, + 0x000f000d, + 0x000b0009, + 0x003e003c, + 0x003a0038, + 0x002e002c, + 0x002a0028, + 0x001e001c, + 0x001a0018, + 0x000e000c, + 0x000a0008); + // merge values from two regs + // 0-- 1-- + // 8-- 9-- + // 2-- 3-- + // 10-- 11-- + // 4-- 5-- + // 12-- 13-- + // 6-- 7-- + // 14-- 15-- +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; i++) { + u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]); + u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]); + } +} + +// TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16 +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607 +template <> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst) { + __m256i t[16]; + // load from src to registers + // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 + // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 + // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 + // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 + // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15 + // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 + // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15 + // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15 + // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15 + // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15 + // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 + // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15 + // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15 + // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 + // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 + // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; i++) { + t[i] = + _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); + } + + __m512i u[8]; + _transpose_mxn_half_16_16(t, u); + +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; i++) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x0)); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x01)); + } +} + +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607 +template <> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst) { + __m256i t[16]; + // load from src to registers + // Same matrix indices as above transpose_mxn +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; i++) { + t[i] = + _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); + } + + __m512i u[8]; + _transpose_mxn_half_16_16(t, u); + +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; i++) { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x0)); + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst), + _mm512_extracti32x8_epi32(u[i], 0x01)); + } +} + +static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { + // t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59 + // t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 63 + // t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123 + // t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127 + // t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171 144 + // ... 187 t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 + // 175 148 ... 191 t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 + // 234 203 235 208 ... 251 t[7]: 196 228 197 229 198 230 199 231 204 236 205 + // 237 206 238 207 239 212 ... 255 t[8]: 256 288 257 289 258 290 259 291 264 + // 296 265 297 266 298 267 299 272 ... 315 t[9]: 260 292 261 293 262 294 263 + // 295 268 300 269 301 270 302 271 303 276 ... 319 t[10]: 320 352 321 353 322 + // 354 323 355 328 360 329 361 330 362 331 363 336 ... 379 t[11]: 324 356 325 + // 357 326 358 327 359 332 364 333 365 334 366 335 367 340 ... 383 t[12]: 384 + // 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427 400 ... 443 + // t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431 404 + // ... 447 t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 + // 491 464 ... 507 t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 + // 494 463 495 468 ... 511 t[16]: 512 544 513 545 514 546 515 547 520 552 521 + // 553 522 554 523 555 528 ... 571 + // ... + // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 + // 980 ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; ++i) { + d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]); + d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]); + } + + // t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121 + // t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123 + // t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125 + // t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127 + // t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233 144 + // ... 249 t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 + // 235 146 ... 251 t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 + // 173 205 237 148 ... 253 t[7]: 134 166 198 230 135 167 199 231 142 174 206 + // 238 143 175 207 239 150 ... 255 t[8]: 256 288 320 352 257 289 321 353 264 + // 296 328 360 265 297 329 361 272 ... 377 t[9]: 258 290 322 354 259 291 323 + // 355 266 298 330 362 267 299 331 363 274 ... 379 t[10]: 260 292 324 356 261 + // 293 325 357 268 300 332 364 269 301 333 365 276 ... 381 t[11]: 262 294 326 + // 358 263 295 327 359 270 302 334 366 271 303 335 367 278 ... 383 t[12]: 384 + // 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489 400 ... 505 + // t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491 402 + // ... 507 t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 + // 493 404 ... 509 t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 + // 431 463 495 406 ... 511 t[16]: 512 544 576 608 513 545 577 609 520 552 584 + // 616 521 553 585 617 528 ... 633 + // ... + // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 + // 918 ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; ++i) { + r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]); + r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]); + r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]); + r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]); + } + + // t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248 + // t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249 + // t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250 + // t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251 + // t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252 + // t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 253 + // t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254 + // t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255 + // t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488 272 + // ... 504 t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 + // 489 273 ... 505 t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 + // 426 458 490 274 ... 506 t[11]: 259 291 323 355 387 419 451 483 267 299 331 + // 363 395 427 459 491 275 ... 507 t[12]: 260 292 324 356 388 420 452 484 268 + // 300 332 364 396 428 460 492 276 ... 508 t[13]: 261 293 325 357 389 421 453 + // 485 269 301 333 365 397 429 461 493 277 ... 509 t[14]: 262 294 326 358 390 + // 422 454 486 270 302 334 366 398 430 462 494 278 ... 510 t[15]: 263 295 327 + // 359 391 423 455 487 271 303 335 367 399 431 463 495 279 ... 511 t[16]: 512 + // 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760 + // ... + // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 + // ... 1023 +#ifndef __msvc_cl__ +#pragma unroll(4) +#endif + for (int i = 0; i < 4; ++i) { + d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]); + d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]); + d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]); + d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]); + d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]); + d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]); + d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]); + d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]); + } + + // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 16 ... 496 + // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 17 ... 497 + // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 18 ... 498 + // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 19 ... 499 + // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 20 ... + // 500 t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 21 + // ... 501 t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 + // 22 ... 502 t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 + // 487 23 ... 503 t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 + // 456 488 24 ... 504 t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 + // 425 457 489 25 ... 505 t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 + // 394 426 458 490 26 ... 506 t[11]: 11 43 75 107 139 171 203 235 267 299 331 + // 363 395 427 459 491 27 ... 507 t[12]: 12 44 76 108 140 172 204 236 268 300 + // 332 364 396 428 460 492 28 ... 508 t[13]: 13 45 77 109 141 173 205 237 269 + // 301 333 365 397 429 461 493 29 ... 509 t[14]: 14 46 78 110 142 174 206 238 + // 270 302 334 366 398 430 462 494 30 ... 510 t[15]: 15 47 79 111 143 175 207 + // 239 271 303 335 367 399 431 463 495 31 ... 511 t[16]: 512 544 576 608 640 + // 672 704 736 768 800 832 864 896 928 960 992 528 ... 1008 + // ... + // t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007 543 + // ... 1023 + __m512i const1 = _mm512_set_epi64( + 0x000000000000000d, + 0x000000000000000c, + 0x0000000000000005, + 0x0000000000000004, + 0x0000000000000009, + 0x0000000000000008, + 0x0000000000000001, + 0x0000000000000000); + __m512i const2 = _mm512_set_epi64( + 0x000000000000000f, + 0x000000000000000e, + 0x0000000000000007, + 0x0000000000000006, + 0x000000000000000b, + 0x000000000000000a, + 0x0000000000000003, + 0x0000000000000002); +#ifndef __msvc_cl__ +#pragma unroll(8) +#endif + for (int i = 0; i < 8; ++i) { + r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/ const1, d[i + 8]); + r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/ const2, d[i + 8]); + r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const1, d[i + 24]); + r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/ const2, d[i + 24]); + } + + // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 512 544 + // ... 992 t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 + // 513 545 ... 993 t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 + // 450 482 514 546 ... 994 t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 + // 387 419 451 483 515 547 ... 995 t[4]: 4 36 68 100 132 164 196 228 260 292 + // 324 356 388 420 452 484 516 548 ... 996 t[5]: 5 37 69 101 133 165 197 229 + // 261 293 325 357 389 421 453 485 517 549 ... 997 t[6]: 6 38 70 102 134 166 + // 198 230 262 294 326 358 390 422 454 486 518 550 ... 998 t[7]: 7 39 71 103 + // 135 167 199 231 263 295 327 359 391 423 455 487 519 551 ... 999 t[8]: 8 40 + // 72 104 136 168 200 232 264 296 328 360 392 424 456 488 520 552 ... 1000 + // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 521 553 + // ... 1001 t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 + // 490 522 554 ... 1002 t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 + // 395 427 459 491 523 555 ... 1003 t[12]: 12 44 76 108 140 172 204 236 268 + // 300 332 364 396 428 460 492 524 556 ... 1004 t[13]: 13 45 77 109 141 173 + // 205 237 269 301 333 365 397 429 461 493 525 557 ... 1005 t[14]: 14 46 78 + // 110 142 174 206 238 270 302 334 366 398 430 462 494 526 558 ... 1006 t[15]: + // 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 527 559 ... + // 1007 t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496 + // 528 560 ... 1008 + // ... + // t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511 543 575 + // ... 1023 + __m512i const3 = _mm512_set_epi64( + 0x000000000000000b, + 0x000000000000000a, + 0x0000000000000009, + 0x0000000000000008, + 0x0000000000000003, + 0x0000000000000002, + 0x0000000000000001, + 0x0000000000000000); + __m512i const4 = _mm512_set_epi64( + 0x000000000000000f, + 0x000000000000000e, + 0x000000000000000d, + 0x000000000000000c, + 0x0000000000000007, + 0x0000000000000006, + 0x0000000000000005, + 0x0000000000000004); +#ifndef __msvc_cl__ +#pragma unroll(16) +#endif + for (int i = 0; i < 16; ++i) { + d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/ const3, r[i + 16]); + d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/ const4, r[i + 16]); + } +} + +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6 +template <> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst, + int M, + int N) { + // load from src + TORCH_CHECK( + M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); + __m512i r[32]; + int i; + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } + } + for (; i < 32; ++i) { + r[i] = _mm512_setzero_si512(); + } + + __m512i d[32]; + _transpose_mxn_half_32_32(r, d); + + // store to dst + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t< + std::is_same_v && + ((M <= 32 && M != 16) || (N <= 32 && N != 16)), + int> = 0> +inline void transpose_mxn( + const BFloat16* src, + int64_t ld_src, + BFloat16* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +template <> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst, + int M, + int N) { + TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); + // load from src + __m512i r[32]; + int i; + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } + } + for (; i < 32; ++i) { + r[i] = _mm512_setzero_si512(); + } + + __m512i d[32]; + _transpose_mxn_half_32_32(r, d); + + // store to dst + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t< + std::is_same_v && + ((M <= 32 && M != 16) || (N <= 32 && N != 16)), + int> = 0> +inline void transpose_mxn( + const Half* src, + int64_t ld_src, + Half* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized16 { + public: + using Vectorized16::Vectorized16; + + using value_type = Half; + + Vectorized frac() const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_lo_mask, 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, nan_hi_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(min), min_lo, min_hi); + cvtfp16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_fp16(o1, o2); +} + +template <> +inline void convert(const Half* src, Half* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto vsrc = + _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +inline void convert(const float* src, Half* dst, int64_t n) { + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = _mm512_loadu_ps(&src[i]); + __m512 b = _mm512_loadu_ps(&src[i + 16]); + + __m512i bf = cvtfp32_fp16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +inline void convert(const double* src, Half* dst, int64_t n) { + auto load_float = [](const double* src) -> __m512 { + // Load one float vector from an array of doubles + __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src)); + __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8)); + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); + }; + + int64_t i; + for (i = 0; i + Vectorized::size() <= n; + i += Vectorized::size()) { + __m512 a = load_float(&src[i]); + __m512 b = load_float(&src[i + 16]); + + __m512i bf = cvtfp32_fp16(a, b); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf); + } + for (; i < n; i++) { + dst[i] = c10::convert(src[i]); + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + cvtfp16_fp32(__m512i(a), a_lo, a_hi); + cvtfp16_fp32(__m512i(b), b_lo, b_hi); + cvtfp16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_fp16(o1, o2); +} + +#define CONVERT_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + __m512 o1, o2; \ + cvt_to_fp32(__m512i(a), o1, o2); \ + return std::make_tuple(o1, o2); \ + } \ + \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + return cvt_from_fp32(__m512(a), __m512(b)); \ + } +CONVERT_VECTORIZED_INIT(BFloat16, bfloat16) +CONVERT_VECTORIZED_INIT(Half, half) + +#else // defined(CPU_CAPABILITY_AVX512) + +#define CONVERT_NON_VECTORIZED_INIT(type, name) \ + inline std::tuple, Vectorized> \ + convert_##name##_float(const Vectorized& a) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr2); \ + for (const auto k : c10::irange(K)) { \ + arr[k] = c10::convert(arr2[k]); \ + } \ + return std::make_tuple( \ + Vectorized::loadu(arr), \ + Vectorized::loadu(arr + Vectorized::size())); \ + } \ + \ + inline Vectorized convert_float_##name( \ + const Vectorized& a, const Vectorized& b) { \ + constexpr int64_t K = Vectorized::size(); \ + __at_align__ float arr[K]; \ + __at_align__ type arr2[K]; \ + a.store(arr); \ + b.store(arr + Vectorized::size()); \ + for (const auto k : c10::irange(K)) { \ + arr2[k] = c10::convert(arr[k]); \ + } \ + return Vectorized::loadu(arr2); \ + } +CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16) +CONVERT_NON_VECTORIZED_INIT(Half, half) + +#endif // defined(CPU_CAPABILITY_AVX512) + +#if defined(CPU_CAPABILITY_AVX512) +#define LOAD_FP32_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + auto values = _mm256_loadu_si256(reinterpret_cast(data)); \ + __m512 out_values; \ + cvt_to_fp32(values, out_values); \ + out = out_values; \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + auto vec = Vectorized::loadu(data); \ + __m512 out1_values, out2_values; \ + cvt_to_fp32(vec, out1_values, out2_values); \ + out1 = out1_values; \ + out2 = out2_values; \ + } +LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16) +LOAD_FP32_VECTORIZED_INIT(Half, fp16) + +#else // defined(CPU_CAPABILITY_AVX512) +#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out) { \ + __at_align__ float values[Vectorized::size()]; \ + for (const auto k : c10::irange(Vectorized::size())) { \ + values[k] = data[k]; \ + } \ + out = Vectorized::loadu(values); \ + } \ + \ + inline void load_fp32_from_##name( \ + const type* data, Vectorized& out1, Vectorized& out2) { \ + load_fp32_from_##name(data, out1); \ + data += Vectorized::size(); \ + load_fp32_from_##name(data, out2); \ + } +LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16) +LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16) + +#endif +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h new file mode 100644 index 0000000000000000000000000000000000000000..0779363c788634d77d10dd700b7c203cae2c206d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h @@ -0,0 +1,661 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m512d values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() { + values = _mm512_setzero_pd(); + } + Vectorized(__m512d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm512_setr_pd( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4) { + values = _mm512_setr_pd( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag()); + } + operator __m512d() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + // NOLINTNEXTLINE(clang-diagnostic-warning) + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_pd( + 0x03, a.values, b.values); // b0000 0001 = b0000 0011 + case 2: + return _mm512_mask_blend_pd( + 0x0C, a.values, b.values); // b0000 0010 = b0000 1100 + case 3: + return _mm512_mask_blend_pd( + 0x0F, a.values, b.values); // b0000 0011 = b0000 1111 + case 4: + return _mm512_mask_blend_pd( + 0x30, a.values, b.values); // b0000 0100 = b0011 0000 + case 5: + return _mm512_mask_blend_pd( + 0x33, a.values, b.values); // b0000 0101 = b0011 0011 + case 6: + return _mm512_mask_blend_pd( + 0x3C, a.values, b.values); // b0000 0110 = b0011 1100 + case 7: + return _mm512_mask_blend_pd( + 0x3F, a.values, b.values); // b0000 0111 = b0011 1111 + case 8: + return _mm512_mask_blend_pd( + 0xC0, a.values, b.values); // b0000 1000 = b1100 0000 + case 9: + return _mm512_mask_blend_pd( + 0xC3, a.values, b.values); // b0000 1001 = b1100 0011 + case 10: + return _mm512_mask_blend_pd( + 0xCC, a.values, b.values); // b0000 1010 = b1100 1100 + case 11: + return _mm512_mask_blend_pd( + 0xCF, a.values, b.values); // b0000 1011 = b1100 1111 + case 12: + return _mm512_mask_blend_pd( + 0xF0, a.values, b.values); // b0000 1100 = b1111 0000 + case 13: + return _mm512_mask_blend_pd( + 0xF3, a.values, b.values); // b0000 1101 = b1111 0011 + case 14: + return _mm512_mask_blend_pd( + 0xFC, a.values, b.values); // b0000 1110 = b1111 1100 + case 15: + return _mm512_mask_blend_pd( + 0xFF, a.values, b.values); // b0000 1111 = b1111 1111 + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values); + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask( + _mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + c10::complex(1) * step, + base + c10::complex(2) * step, + base + c10::complex(3) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2 * size()]; + _mm512_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. + static inline __m512d hadd_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_add_pd( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + static inline __m512d hsub_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_sub_pd( + _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + __m512d abs_2_() const { + auto val_2 = _mm512_mul_pd(values, values); // a*a b*b + return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m512d abs_() const { + auto real = _mm512_movedup_pd(values); // real real + // movehdup_pd does not exist... + auto imag = _mm512_permute_pd(values, 0xff); // imag imag + return Sleef_hypotd8_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm512_and_pd(abs_(), real_mask); // abs 0 + } + __m512d angle_() const { + // angle = atan2(b/a) + auto b_a = _mm512_permute_pd(values, 0x55); // b a + return Sleef_atan2d8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle + return _mm512_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_pd(); + auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ); + auto div = _mm512_div_pd(values, abs); + return _mm512_mask_blend_pd(mask, div, zero); + } + __m512d real_() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000)); + return _mm512_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512d imag_() const { + const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64( + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF)); + return _mm512_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_pd(imag_(), 0x55); // b a + } + __m512d conj_() const { + const __m512d sign_mask = + _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm512_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512d log2_ = _mm512_set1_pd(std::log(2)); + return _mm512_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m512d log10_ = _mm512_set1_pd(std::log(10)); + return _mm512_div_pd(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m512d one = _mm512_set1_pd(1); + + // auto conj = conj_(); + // auto b_a = _mm512_permute_pd(conj, 0x55); //-b a + // auto ab = _mm512_mul_pd(conj, b_a); //-ab + // -ab auto im = _mm512_add_pd(ab, ab); //-2ab -2ab + + // auto val_2 = _mm512_mul_pd(values, values); // a*a + // b*b auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b + // b*b-a*a re = _mm512_sub_pd(one, re); + + // auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); + // //sqrt(re + i*im) auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); + // //ln(iz + sqrt()) return Vectorized(_mm512_permute_pd(ln.values, + // 0x55)).conj(); //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m512d pi_2 = + _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0); + return _mm512_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expd8_u10(values); //exp(a) exp(b) exp = + // _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm512_mask_blend_pd(0xAA, + // _mm512_permute_pd(sin_cos.y, 0x55), + // sin_cos.x); //cos(b) + // sin(b) + // return _mm512_mul_pd(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m512d ln_2 = _mm512_set1_pd(c10::ln_2); + Vectorized> scaled_values = + _mm512_mul_pd(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized> floor() const { + return _mm512_floor_pd(values); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_pd(); + return _mm512_sub_pd(zero, values); + } + Vectorized> round() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator!=( + const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator<( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512d sign_mask = + _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm512_mul_pd(a, b); // ac bd + + auto d_c = _mm512_permute_pd(b, 0x55); // d c + d_c = _mm512_xor_pd(sign_mask, d_c); // d -c + auto ad_bc = _mm512_mul_pd(a, d_c); // ad -bc + + auto ret = Vectorized>::hsub_pd( + ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm512_set1_pd(-0.f); + // auto fabs_cd = _mm512_andnot_pd(mask, b); // |c| |d| + // auto fabs_dc = _mm512_permute_pd(fabs_cd, 0x55); // |d| |c| + // auto scale = _mm512_rcp14_pd(_mm512_max_pd(fabs_cd, fabs_dc)); // 1/sc + // 1/sc auto a2 = _mm512_mul_pd(a, scale); // a/sc b/sc auto b2 = + // _mm512_mul_pd(b, scale); // c/sc d/sc auto acbd2 = + // _mm512_mul_pd(a2, b2); + + // const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); auto dc2 = _mm512_permute_pd(b2, 0x55); // d/sc c/sc + // dc2 = _mm512_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm512_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = Vectorized>::hadd_pd(acbd2, adbc2); + // //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm512_div_pd(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex + out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm512_loadu_pd(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); auto c_d = _mm512_xor_pd(sign_mask, values); //c -d + // return _mm512_div_pd(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm512_add_pd(i, values)); // a + // 1+b auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(max, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(min, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_pd(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm512_set1_pd(1.0)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm512_set1_pd(1.0)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h new file mode 100644 index 0000000000000000000000000000000000000000..59fce4ea931c3671dfe3c87387a524bcc6666690 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h @@ -0,0 +1,1229 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for> : std::bool_constant { +}; + +template <> +class Vectorized> { + private: + __m512 values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() { + values = _mm512_setzero_ps(); + } + Vectorized(__m512 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm512_setr_ps( + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value, + real_value, + imag_value); + } + Vectorized( + c10::complex val1, + c10::complex val2, + c10::complex val3, + c10::complex val4, + c10::complex val5, + c10::complex val6, + c10::complex val7, + c10::complex val8) { + values = _mm512_setr_ps( + val1.real(), + val1.imag(), + val2.real(), + val2.imag(), + val3.real(), + val3.imag(), + val4.real(), + val4.imag(), + val5.real(), + val5.imag(), + val6.real(), + val6.imag(), + val7.real(), + val7.imag(), + val8.real(), + val8.imag()); + } + operator __m512() const { + return values; + } + template + static Vectorized> blend( + const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + static_assert(mask > -1 && mask < 256, "Unexpected mask value"); + // The compiler would hopefully convert this switch condition + // into a jump table + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_ps(0x03, a.values, b.values); + case 2: + return _mm512_mask_blend_ps(0x0C, a.values, b.values); + case 3: + return _mm512_mask_blend_ps(0x0F, a.values, b.values); + case 4: + return _mm512_mask_blend_ps(0x30, a.values, b.values); + case 5: + return _mm512_mask_blend_ps(0x33, a.values, b.values); + case 6: + return _mm512_mask_blend_ps(0x3C, a.values, b.values); + case 7: + return _mm512_mask_blend_ps(0x3F, a.values, b.values); + case 8: + return _mm512_mask_blend_ps(0xC0, a.values, b.values); + case 9: + return _mm512_mask_blend_ps(0xC3, a.values, b.values); + case 10: + return _mm512_mask_blend_ps(0xCC, a.values, b.values); + case 11: + return _mm512_mask_blend_ps(0xCF, a.values, b.values); + case 12: + return _mm512_mask_blend_ps(0xF0, a.values, b.values); + case 13: + return _mm512_mask_blend_ps(0xF3, a.values, b.values); + case 14: + return _mm512_mask_blend_ps(0xFC, a.values, b.values); + case 15: + return _mm512_mask_blend_ps(0xFF, a.values, b.values); + case 16: + return _mm512_mask_blend_ps(0x300, a.values, b.values); + case 17: + return _mm512_mask_blend_ps(0x303, a.values, b.values); + case 18: + return _mm512_mask_blend_ps(0x30C, a.values, b.values); + case 19: + return _mm512_mask_blend_ps(0x30F, a.values, b.values); + case 20: + return _mm512_mask_blend_ps(0x330, a.values, b.values); + case 21: + return _mm512_mask_blend_ps(0x333, a.values, b.values); + case 22: + return _mm512_mask_blend_ps(0x33C, a.values, b.values); + case 23: + return _mm512_mask_blend_ps(0x33F, a.values, b.values); + case 24: + return _mm512_mask_blend_ps(0x3C0, a.values, b.values); + case 25: + return _mm512_mask_blend_ps(0x3C3, a.values, b.values); + case 26: + return _mm512_mask_blend_ps(0x3CC, a.values, b.values); + case 27: + return _mm512_mask_blend_ps(0x3CF, a.values, b.values); + case 28: + return _mm512_mask_blend_ps(0x3F0, a.values, b.values); + case 29: + return _mm512_mask_blend_ps(0x3F3, a.values, b.values); + case 30: + return _mm512_mask_blend_ps(0x3FC, a.values, b.values); + case 31: + return _mm512_mask_blend_ps(0x3FF, a.values, b.values); + case 32: + return _mm512_mask_blend_ps(0xC00, a.values, b.values); + case 33: + return _mm512_mask_blend_ps(0xC03, a.values, b.values); + case 34: + return _mm512_mask_blend_ps(0xC0C, a.values, b.values); + case 35: + return _mm512_mask_blend_ps(0xC0F, a.values, b.values); + case 36: + return _mm512_mask_blend_ps(0xC30, a.values, b.values); + case 37: + return _mm512_mask_blend_ps(0xC33, a.values, b.values); + case 38: + return _mm512_mask_blend_ps(0xC3C, a.values, b.values); + case 39: + return _mm512_mask_blend_ps(0xC3F, a.values, b.values); + case 40: + return _mm512_mask_blend_ps(0xCC0, a.values, b.values); + case 41: + return _mm512_mask_blend_ps(0xCC3, a.values, b.values); + case 42: + return _mm512_mask_blend_ps(0xCCC, a.values, b.values); + case 43: + return _mm512_mask_blend_ps(0xCCF, a.values, b.values); + case 44: + return _mm512_mask_blend_ps(0xCF0, a.values, b.values); + case 45: + return _mm512_mask_blend_ps(0xCF3, a.values, b.values); + case 46: + return _mm512_mask_blend_ps(0xCFC, a.values, b.values); + case 47: + return _mm512_mask_blend_ps(0xCFF, a.values, b.values); + case 48: + return _mm512_mask_blend_ps(0xF00, a.values, b.values); + case 49: + return _mm512_mask_blend_ps(0xF03, a.values, b.values); + case 50: + return _mm512_mask_blend_ps(0xF0C, a.values, b.values); + case 51: + return _mm512_mask_blend_ps(0xF0F, a.values, b.values); + case 52: + return _mm512_mask_blend_ps(0xF30, a.values, b.values); + case 53: + return _mm512_mask_blend_ps(0xF33, a.values, b.values); + case 54: + return _mm512_mask_blend_ps(0xF3C, a.values, b.values); + case 55: + return _mm512_mask_blend_ps(0xF3F, a.values, b.values); + case 56: + return _mm512_mask_blend_ps(0xFC0, a.values, b.values); + case 57: + return _mm512_mask_blend_ps(0xFC3, a.values, b.values); + case 58: + return _mm512_mask_blend_ps(0xFCC, a.values, b.values); + case 59: + return _mm512_mask_blend_ps(0xFCF, a.values, b.values); + case 60: + return _mm512_mask_blend_ps(0xFF0, a.values, b.values); + case 61: + return _mm512_mask_blend_ps(0xFF3, a.values, b.values); + case 62: + return _mm512_mask_blend_ps(0xFFC, a.values, b.values); + case 63: + return _mm512_mask_blend_ps(0xFFF, a.values, b.values); + case 64: + return _mm512_mask_blend_ps(0x3000, a.values, b.values); + case 65: + return _mm512_mask_blend_ps(0x3003, a.values, b.values); + case 66: + return _mm512_mask_blend_ps(0x300C, a.values, b.values); + case 67: + return _mm512_mask_blend_ps(0x300F, a.values, b.values); + case 68: + return _mm512_mask_blend_ps(0x3030, a.values, b.values); + case 69: + return _mm512_mask_blend_ps(0x3033, a.values, b.values); + case 70: + return _mm512_mask_blend_ps(0x303C, a.values, b.values); + case 71: + return _mm512_mask_blend_ps(0x303F, a.values, b.values); + case 72: + return _mm512_mask_blend_ps(0x30C0, a.values, b.values); + case 73: + return _mm512_mask_blend_ps(0X30C3, a.values, b.values); + case 74: + return _mm512_mask_blend_ps(0x30CC, a.values, b.values); + case 75: + return _mm512_mask_blend_ps(0x30CF, a.values, b.values); + case 76: + return _mm512_mask_blend_ps(0x30F0, a.values, b.values); + case 77: + return _mm512_mask_blend_ps(0x30F3, a.values, b.values); + case 78: + return _mm512_mask_blend_ps(0x30FC, a.values, b.values); + case 79: + return _mm512_mask_blend_ps(0x30FF, a.values, b.values); + case 80: + return _mm512_mask_blend_ps(0x3300, a.values, b.values); + case 81: + return _mm512_mask_blend_ps(0X3303, a.values, b.values); + case 82: + return _mm512_mask_blend_ps(0x330C, a.values, b.values); + case 83: + return _mm512_mask_blend_ps(0x330F, a.values, b.values); + case 84: + return _mm512_mask_blend_ps(0x3330, a.values, b.values); + case 85: + return _mm512_mask_blend_ps(0x3333, a.values, b.values); + case 86: + return _mm512_mask_blend_ps(0x333C, a.values, b.values); + case 87: + return _mm512_mask_blend_ps(0X333F, a.values, b.values); + case 88: + return _mm512_mask_blend_ps(0x33C0, a.values, b.values); + case 89: + return _mm512_mask_blend_ps(0x33C3, a.values, b.values); + case 90: + return _mm512_mask_blend_ps(0x33CC, a.values, b.values); + case 91: + return _mm512_mask_blend_ps(0x33CF, a.values, b.values); + case 92: + return _mm512_mask_blend_ps(0x33F0, a.values, b.values); + case 93: + return _mm512_mask_blend_ps(0x33F3, a.values, b.values); + case 94: + return _mm512_mask_blend_ps(0x33FC, a.values, b.values); + case 95: + return _mm512_mask_blend_ps(0x33FF, a.values, b.values); + case 96: + return _mm512_mask_blend_ps(0X3C00, a.values, b.values); + case 97: + return _mm512_mask_blend_ps(0x3C03, a.values, b.values); + case 98: + return _mm512_mask_blend_ps(0x3C0C, a.values, b.values); + case 99: + return _mm512_mask_blend_ps(0x3C0F, a.values, b.values); + case 100: + return _mm512_mask_blend_ps(0x3C30, a.values, b.values); + case 101: + return _mm512_mask_blend_ps(0x3C33, a.values, b.values); + case 102: + return _mm512_mask_blend_ps(0x3C3C, a.values, b.values); + case 103: + return _mm512_mask_blend_ps(0x3C3F, a.values, b.values); + case 104: + return _mm512_mask_blend_ps(0x3CC0, a.values, b.values); + case 105: + return _mm512_mask_blend_ps(0x3CC3, a.values, b.values); + case 106: + return _mm512_mask_blend_ps(0x3CCC, a.values, b.values); + case 107: + return _mm512_mask_blend_ps(0x3CCF, a.values, b.values); + case 108: + return _mm512_mask_blend_ps(0x3CF0, a.values, b.values); + case 109: + return _mm512_mask_blend_ps(0x3CF3, a.values, b.values); + case 110: + return _mm512_mask_blend_ps(0x3CFC, a.values, b.values); + case 111: + return _mm512_mask_blend_ps(0x3CFF, a.values, b.values); + case 112: + return _mm512_mask_blend_ps(0x3F00, a.values, b.values); + case 113: + return _mm512_mask_blend_ps(0x3F03, a.values, b.values); + case 114: + return _mm512_mask_blend_ps(0x3F0C, a.values, b.values); + case 115: + return _mm512_mask_blend_ps(0x3F0F, a.values, b.values); + case 116: + return _mm512_mask_blend_ps(0x3F30, a.values, b.values); + case 117: + return _mm512_mask_blend_ps(0x3F33, a.values, b.values); + case 118: + return _mm512_mask_blend_ps(0x3F3C, a.values, b.values); + case 119: + return _mm512_mask_blend_ps(0x3F3F, a.values, b.values); + case 120: + return _mm512_mask_blend_ps(0x3FC0, a.values, b.values); + case 121: + return _mm512_mask_blend_ps(0x3FC3, a.values, b.values); + case 122: + return _mm512_mask_blend_ps(0x3FCC, a.values, b.values); + case 123: + return _mm512_mask_blend_ps(0x3FCF, a.values, b.values); + case 124: + return _mm512_mask_blend_ps(0x3FF0, a.values, b.values); + case 125: + return _mm512_mask_blend_ps(0x3FF3, a.values, b.values); + case 126: + return _mm512_mask_blend_ps(0x3FFC, a.values, b.values); + case 127: + return _mm512_mask_blend_ps(0x3FFF, a.values, b.values); + case 128: + return _mm512_mask_blend_ps(0xC000, a.values, b.values); + case 129: + return _mm512_mask_blend_ps(0xC003, a.values, b.values); + case 130: + return _mm512_mask_blend_ps(0xC00C, a.values, b.values); + case 131: + return _mm512_mask_blend_ps(0xC00F, a.values, b.values); + case 132: + return _mm512_mask_blend_ps(0xC030, a.values, b.values); + case 133: + return _mm512_mask_blend_ps(0xC033, a.values, b.values); + case 134: + return _mm512_mask_blend_ps(0xC03C, a.values, b.values); + case 135: + return _mm512_mask_blend_ps(0xC03F, a.values, b.values); + case 136: + return _mm512_mask_blend_ps(0xC0C0, a.values, b.values); + case 137: + return _mm512_mask_blend_ps(0xC0C3, a.values, b.values); + case 138: + return _mm512_mask_blend_ps(0xC0CC, a.values, b.values); + case 139: + return _mm512_mask_blend_ps(0xC0CF, a.values, b.values); + case 140: + return _mm512_mask_blend_ps(0xC0F0, a.values, b.values); + case 141: + return _mm512_mask_blend_ps(0xC0F3, a.values, b.values); + case 142: + return _mm512_mask_blend_ps(0xC0FC, a.values, b.values); + case 143: + return _mm512_mask_blend_ps(0xC0FF, a.values, b.values); + case 144: + return _mm512_mask_blend_ps(0xC300, a.values, b.values); + case 145: + return _mm512_mask_blend_ps(0xC303, a.values, b.values); + case 146: + return _mm512_mask_blend_ps(0xC30C, a.values, b.values); + case 147: + return _mm512_mask_blend_ps(0xC30F, a.values, b.values); + case 148: + return _mm512_mask_blend_ps(0xC330, a.values, b.values); + case 149: + return _mm512_mask_blend_ps(0xC333, a.values, b.values); + case 150: + return _mm512_mask_blend_ps(0xC33C, a.values, b.values); + case 151: + return _mm512_mask_blend_ps(0xC33F, a.values, b.values); + case 152: + return _mm512_mask_blend_ps(0xC3C0, a.values, b.values); + case 153: + return _mm512_mask_blend_ps(0xC3C3, a.values, b.values); + case 154: + return _mm512_mask_blend_ps(0xC3CC, a.values, b.values); + case 155: + return _mm512_mask_blend_ps(0xC3CF, a.values, b.values); + case 156: + return _mm512_mask_blend_ps(0xC3F0, a.values, b.values); + case 157: + return _mm512_mask_blend_ps(0xC3F3, a.values, b.values); + case 158: + return _mm512_mask_blend_ps(0xC3FC, a.values, b.values); + case 159: + return _mm512_mask_blend_ps(0xC3FF, a.values, b.values); + case 160: + return _mm512_mask_blend_ps(0xCC00, a.values, b.values); + case 161: + return _mm512_mask_blend_ps(0xCC03, a.values, b.values); + case 162: + return _mm512_mask_blend_ps(0xCC0C, a.values, b.values); + case 163: + return _mm512_mask_blend_ps(0xCC0F, a.values, b.values); + case 164: + return _mm512_mask_blend_ps(0xCC30, a.values, b.values); + case 165: + return _mm512_mask_blend_ps(0xCC33, a.values, b.values); + case 166: + return _mm512_mask_blend_ps(0xCC3C, a.values, b.values); + case 167: + return _mm512_mask_blend_ps(0xCC3F, a.values, b.values); + case 168: + return _mm512_mask_blend_ps(0xCCC0, a.values, b.values); + case 169: + return _mm512_mask_blend_ps(0xCCC3, a.values, b.values); + case 170: + return _mm512_mask_blend_ps(0xCCCC, a.values, b.values); + case 171: + return _mm512_mask_blend_ps(0xCCCF, a.values, b.values); + case 172: + return _mm512_mask_blend_ps(0xCCF0, a.values, b.values); + case 173: + return _mm512_mask_blend_ps(0xCCF3, a.values, b.values); + case 174: + return _mm512_mask_blend_ps(0xCCFC, a.values, b.values); + case 175: + return _mm512_mask_blend_ps(0xCCFF, a.values, b.values); + case 176: + return _mm512_mask_blend_ps(0xCF00, a.values, b.values); + case 177: + return _mm512_mask_blend_ps(0xCF03, a.values, b.values); + case 178: + return _mm512_mask_blend_ps(0xCF0C, a.values, b.values); + case 179: + return _mm512_mask_blend_ps(0xCF0F, a.values, b.values); + case 180: + return _mm512_mask_blend_ps(0xCF30, a.values, b.values); + case 181: + return _mm512_mask_blend_ps(0xCF33, a.values, b.values); + case 182: + return _mm512_mask_blend_ps(0xCF3C, a.values, b.values); + case 183: + return _mm512_mask_blend_ps(0xCF3F, a.values, b.values); + case 184: + return _mm512_mask_blend_ps(0xCFC0, a.values, b.values); + case 185: + return _mm512_mask_blend_ps(0xCFC3, a.values, b.values); + case 186: + return _mm512_mask_blend_ps(0xCFCC, a.values, b.values); + case 187: + return _mm512_mask_blend_ps(0xCFCF, a.values, b.values); + case 188: + return _mm512_mask_blend_ps(0xCFF0, a.values, b.values); + case 189: + return _mm512_mask_blend_ps(0xCFF3, a.values, b.values); + case 190: + return _mm512_mask_blend_ps(0xCFFC, a.values, b.values); + case 191: + return _mm512_mask_blend_ps(0xCFFF, a.values, b.values); + case 192: + return _mm512_mask_blend_ps(0xF000, a.values, b.values); + case 193: + return _mm512_mask_blend_ps(0xF003, a.values, b.values); + case 194: + return _mm512_mask_blend_ps(0xF00C, a.values, b.values); + case 195: + return _mm512_mask_blend_ps(0xF00F, a.values, b.values); + case 196: + return _mm512_mask_blend_ps(0xF030, a.values, b.values); + case 197: + return _mm512_mask_blend_ps(0xF033, a.values, b.values); + case 198: + return _mm512_mask_blend_ps(0xF03C, a.values, b.values); + case 199: + return _mm512_mask_blend_ps(0xF03F, a.values, b.values); + case 200: + return _mm512_mask_blend_ps(0XF0C0, a.values, b.values); + case 201: + return _mm512_mask_blend_ps(0xF0C3, a.values, b.values); + case 202: + return _mm512_mask_blend_ps(0xF0CC, a.values, b.values); + case 203: + return _mm512_mask_blend_ps(0xF0CF, a.values, b.values); + case 204: + return _mm512_mask_blend_ps(0xF0F0, a.values, b.values); + case 205: + return _mm512_mask_blend_ps(0xF0F3, a.values, b.values); + case 206: + return _mm512_mask_blend_ps(0xF0FC, a.values, b.values); + case 207: + return _mm512_mask_blend_ps(0xF0FF, a.values, b.values); + case 208: + return _mm512_mask_blend_ps(0XF300, a.values, b.values); + case 209: + return _mm512_mask_blend_ps(0xF303, a.values, b.values); + case 210: + return _mm512_mask_blend_ps(0xF30C, a.values, b.values); + case 211: + return _mm512_mask_blend_ps(0xF30F, a.values, b.values); + case 212: + return _mm512_mask_blend_ps(0xF330, a.values, b.values); + case 213: + return _mm512_mask_blend_ps(0xF333, a.values, b.values); + case 214: + return _mm512_mask_blend_ps(0XF33C, a.values, b.values); + case 215: + return _mm512_mask_blend_ps(0xF33F, a.values, b.values); + case 216: + return _mm512_mask_blend_ps(0xF3C0, a.values, b.values); + case 217: + return _mm512_mask_blend_ps(0xF3C3, a.values, b.values); + case 218: + return _mm512_mask_blend_ps(0xF3CC, a.values, b.values); + case 219: + return _mm512_mask_blend_ps(0xF3CF, a.values, b.values); + case 220: + return _mm512_mask_blend_ps(0xF3F0, a.values, b.values); + case 221: + return _mm512_mask_blend_ps(0xF3F3, a.values, b.values); + case 222: + return _mm512_mask_blend_ps(0xF3FC, a.values, b.values); + case 223: + return _mm512_mask_blend_ps(0XF3FF, a.values, b.values); + case 224: + return _mm512_mask_blend_ps(0xFC00, a.values, b.values); + case 225: + return _mm512_mask_blend_ps(0xFC03, a.values, b.values); + case 226: + return _mm512_mask_blend_ps(0xFC0C, a.values, b.values); + case 227: + return _mm512_mask_blend_ps(0xFC0F, a.values, b.values); + case 228: + return _mm512_mask_blend_ps(0xFC30, a.values, b.values); + case 229: + return _mm512_mask_blend_ps(0xFC33, a.values, b.values); + case 230: + return _mm512_mask_blend_ps(0xFC3C, a.values, b.values); + case 231: + return _mm512_mask_blend_ps(0xFC3F, a.values, b.values); + case 232: + return _mm512_mask_blend_ps(0xFCC0, a.values, b.values); + case 233: + return _mm512_mask_blend_ps(0xFCC3, a.values, b.values); + case 234: + return _mm512_mask_blend_ps(0xFCCC, a.values, b.values); + case 235: + return _mm512_mask_blend_ps(0xFCCF, a.values, b.values); + case 236: + return _mm512_mask_blend_ps(0xFCF0, a.values, b.values); + case 237: + return _mm512_mask_blend_ps(0xFCF3, a.values, b.values); + case 238: + return _mm512_mask_blend_ps(0xFCFC, a.values, b.values); + case 239: + return _mm512_mask_blend_ps(0xFCFF, a.values, b.values); + case 240: + return _mm512_mask_blend_ps(0xFF00, a.values, b.values); + case 241: + return _mm512_mask_blend_ps(0xFF03, a.values, b.values); + case 242: + return _mm512_mask_blend_ps(0xFF0C, a.values, b.values); + case 243: + return _mm512_mask_blend_ps(0xFF0F, a.values, b.values); + case 244: + return _mm512_mask_blend_ps(0xFF30, a.values, b.values); + case 245: + return _mm512_mask_blend_ps(0xFF33, a.values, b.values); + case 246: + return _mm512_mask_blend_ps(0xFF3C, a.values, b.values); + case 247: + return _mm512_mask_blend_ps(0xFF3F, a.values, b.values); + case 248: + return _mm512_mask_blend_ps(0xFFC0, a.values, b.values); + case 249: + return _mm512_mask_blend_ps(0xFFC3, a.values, b.values); + case 250: + return _mm512_mask_blend_ps(0xFFCC, a.values, b.values); + case 251: + return _mm512_mask_blend_ps(0xFFCF, a.values, b.values); + case 252: + return _mm512_mask_blend_ps(0xFFF0, a.values, b.values); + case 253: + return _mm512_mask_blend_ps(0xFFF3, a.values, b.values); + case 254: + return _mm512_mask_blend_ps(0xFFFC, a.values, b.values); + default: + break; + } + return b; + } + static Vectorized> blendv( + const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_ps(mask.values, mask.values); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized> arange( + c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>( + base, + base + step, + base + c10::complex(2) * step, + base + c10::complex(3) * step, + base + c10::complex(4) * step, + base + c10::complex(5) * step, + base + c10::complex(6) * step, + base + c10::complex(7) * step); + } + static Vectorized> set( + const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized> loadu( + const void* ptr, + int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2 * size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(2 * size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2 * size()]; + _mm512_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. + static inline __m512 hadd_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32( + 30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_add_ps( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + static inline __m512 hsub_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32( + 30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32( + 31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_sub_ps( + _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map( + c10::complex (*const f)(const c10::complex&)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m512 abs_2_() const { + auto val_2 = _mm512_mul_ps(values, values); // a*a b*b + auto ret = hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return ret; + } + __m512 abs_() const { + auto real = _mm512_moveldup_ps(values); // real real + auto imag = _mm512_movehdup_ps(values); // imag imag + return Sleef_hypotf16_u05(real, imag); // abs abs + } + Vectorized> abs() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm512_and_ps(abs_(), real_mask); // abs 0 + } + __m512 angle_() const { + // angle = atan2(b/a) + auto b_a = _mm512_permute_ps(values, 0xB1); // b a + return Sleef_atan2f16_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + auto angle = _mm512_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm512_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_ps(); + auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ); + auto div = _mm512_div_ps(values, abs); + return _mm512_mask_blend_ps(mask, div, zero); + } + __m512 real_() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000)); + return _mm512_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512 imag_() const { + const __m512 imag_mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF, + 0x00000000, + 0xFFFFFFFF)); + return _mm512_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_ps(imag_(), 0xB1); // b a + } + __m512 conj_() const { + const __m512 sign_mask = _mm512_setr_ps( + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0); + return _mm512_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number + // performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512 log2_ = _mm512_set1_ps(std::log(2)); + return _mm512_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m512 log10_ = _mm512_set1_ps(std::log(10)); + return _mm512_div_ps(log(), log10_); + } + Vectorized> log1p() const { + return map(std::log1p); + } + Vectorized> asin() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // // asin(x) + // // = -i*ln(iz + sqrt(1 -z^2)) + // // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + // const __m512 one = _mm512_set1_ps(1); + + // auto conj = conj_(); + // auto b_a = _mm512_permute_ps(conj, 0xB1); //-b a + // auto ab = _mm512_mul_ps(conj, b_a); //-ab + // -ab auto im = _mm512_add_ps(ab, ab); //-2ab -2ab + + // auto val_2 = _mm512_mul_ps(values, values); // a*a + // b*b auto re = hsub_ps(val_2, _mm512_permute_ps(val_2, 0xB1)); // a*a-b*b + // b*b-a*a re = _mm512_sub_ps(one, re); + + // auto root = Vectorized(_mm512_mask_blend_ps(0xAAAA, re, im)).sqrt(); + // //sqrt(re + i*im) auto ln = Vectorized(_mm512_add_ps(b_a, root)).log(); + // //ln(iz + sqrt()) return Vectorized(_mm512_permute_ps(ln.values, + // 0xB1)).conj(); //-i*ln() + return map(std::asin); + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atanh() const { + return map(std::atanh); + } + Vectorized> exp() const { + // TODO: The vectorized implementation requires special handling for the + // case where real number/imag number is 0/Inf/NaN. + // //exp(a + bi) + // // = exp(a)*(cos(b) + sin(b)i) + // auto exp = Sleef_expf16_u10(values); //exp(a) exp(b) exp = + // _mm512_mask_blend_ps(0xAAAA, exp, _mm512_permute_ps(exp, 0xB1)); //exp(a) + // exp(a) + + // auto sin_cos = Sleef_sincosf16_u10(values); //[sin(a), cos(a)] [sin(b), + // cos(b)] auto cos_sin = _mm512_mask_blend_ps(0xAAAA, + // _mm512_permute_ps(sin_cos.y, 0xB1), + // sin_cos.x); //cos(b) + // sin(b) + // return _mm512_mul_ps(exp, cos_sin); + return map(std::exp); + } + Vectorized> exp2() const { + // Use identity 2**x = exp(log(2) * x) + const __m512 ln_2 = _mm512_set1_ps(c10::ln_2); + Vectorized> scaled_values = _mm512_mul_ps(values, ln_2); + return scaled_values.exp(); + } + Vectorized> expm1() const { + return map(std::expm1); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized> floor() const { + return _mm512_floor_ps(values); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_ps(); + return _mm512_sub_ps(zero, values); + } + Vectorized> round() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow( + const Vectorized>& exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (const auto i : c10::irange(size())) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==( + const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator!=( + const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator<( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=( + const Vectorized>& other [[maybe_unused]]) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq( + const Vectorized>& other) const; + Vectorized> ne( + const Vectorized>& other) const; +}; + +template <> +Vectorized> inline operator+( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized> inline operator-( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized> inline operator*( + const Vectorized>& a, + const Vectorized>& b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512 sign_mask = _mm512_setr_ps( + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0, + 0.0, + -0.0); + auto ac_bd = _mm512_mul_ps(a, b); // ac bd + + auto d_c = _mm512_permute_ps(b, 0xB1); // d c + d_c = _mm512_xor_ps(sign_mask, d_c); // d -c + auto ad_bc = _mm512_mul_ps(a, d_c); // ad -bc + + auto ret = Vectorized>::hsub_ps( + ac_bd, ad_bc); // ac - bd ad + bc + return ret; +} + +template <> +Vectorized> inline operator/( + const Vectorized>& a, + const Vectorized>& b) { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // auto mask = _mm512_set1_ps(-0.f); + // auto fabs_cd = _mm512_andnot_ps(mask, b); // |c| |d| + // auto fabs_dc = _mm512_permute_ps(fabs_cd, 0xB1); // |d| |c| + // auto scale = _mm512_rcp14_ps(_mm512_max_ps(fabs_cd, fabs_dc)); // 1/sc + // 1/sc auto a2 = _mm512_mul_ps(a, scale); // a/sc b/sc auto b2 = + // _mm512_mul_ps(b, scale); // c/sc d/sc auto acbd2 = + // _mm512_mul_ps(a2, b2); + + // const __m512 sign_mask = _mm512_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0, + // -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + // -0.0, 0.0); + // auto dc2 = _mm512_permute_ps(b2, 0xB1); // d/sc c/sc + // dc2 = _mm512_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc + // auto adbc2 = _mm512_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 + // auto res2 = Vectorized>::hadd_ps(acbd2, adbc2); + // //(ac+bd)/sc^2 (bc-ad)/sc^2 + + // // get the denominator + // auto denom2 = Vectorized>(b2).abs_2_(); // + // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 res2 = _mm512_div_ps(res2, denom2); return + // res2; + __at_align__ c10::complex + tmp1[Vectorized>::size()]; + __at_align__ c10::complex + tmp2[Vectorized>::size()]; + __at_align__ c10::complex out[Vectorized>::size()]; + a.store(tmp1); + b.store(tmp2); + for (const auto i : c10::irange(Vectorized>::size())) { + out[i] = tmp1[i] / tmp2[i]; + } + return _mm512_loadu_ps(reinterpret_cast(out)); +} + +// reciprocal. Implement this here so we can use multiplication. +inline Vectorized> Vectorized< + c10::complex>::reciprocal() const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // //re + im*i = (a + bi) / (c + di) + // //re = (ac + bd)/abs_2() = c/abs_2() + // //im = (bc - ad)/abs_2() = d/abs_2() + // const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0, + // 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + // 0.0, -0.0); + // auto c_d = _mm512_xor_ps(sign_mask, values); //c -d + // return _mm512_div_ps(c_d, abs_2_()); + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = c10::complex(1) / tmp[i]; + } + return loadu(tmp); +} + +inline Vectorized> Vectorized>::atan() + const { + // TODO: The vectorized implementation requires special handling for the case + // where real number/imag number is 0/Inf/NaN. + // // atan(x) = i/2 * ln((i + z)/(i - z)) + // const __m512 i = _mm512_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + // 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + // const Vectorized i_half = _mm512_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5, + // 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, + // 0.5); + + // auto sum = Vectorized(_mm512_add_ps(i, values)); // a + // 1+b auto sub = Vectorized(_mm512_sub_ps(i, values)); // -a 1-b auto + // ln = (sum/sub).log(); // ln((i + + // z)/(i - z)) return i_half*ln; // i/2*ln() + return map(std::atan); +} + +template <> +Vectorized> inline maximum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(max, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline minimum( + const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(min, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline operator&( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized> inline operator|( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized> inline operator^( + const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_ps(a, b); +} + +inline Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + auto eq = (*this == other); // compares real and imag individually + // If both real numbers and imag numbers are equal, then the complex numbers + // are equal + return (eq.real() & eq.imag()) & + Vectorized>(_mm512_set1_ps(1.0f)); +} + +inline Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + auto ne = (*this != other); // compares real and imag individually + // If either real numbers or imag numbers are not equal, then the complex + // numbers are not equal + return (ne.real() | ne.imag()) & + Vectorized>(_mm512_set1_ps(1.0f)); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..44d8b70fa3c512d3b30557631b7cfed674252df9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_convert.h @@ -0,0 +1,345 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + __m512 value; + cvtbf16_fp32(_mm512_castsi512_si256(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + __m512 value; + cvtfp16_fp32(_mm512_castsi512_si256(src[0]), value); + result[0] = value; + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = _mm512_castsi256_si512(cvtfp32_bf16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = _mm512_castsi256_si512(cvtfp32_fp16(src[0])); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm512_cvtepi64_ps(src[0]); + auto high = _mm512_cvtepi64_ps(src[1]); + return Vectorized( + _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm512_cvt_roundps_epi64( + _mm512_castps512_ps256(src[0]), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + result[1] = _mm512_cvt_roundps_epi64( + _mm512_extractf32x8_ps(src[0], 1), + _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto low = _mm512_cvtepi64_epi32(src[0]); + auto high = _mm512_cvtepi64_epi32(src[1]); + return Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + at::vec::VectorizedN result; + result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src[0])); + result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src[0], 1)); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_castsi512_si128(src[0]); + return Vectorized(_mm512_cvtepi8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_castsi512_si128(src[0]); + return Vectorized(_mm512_cvtepu8_epi32(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm512_cvttps_epi32(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + return Vectorized(_mm512_cvtepi32_ps(src[0])); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src256 = _mm512_castsi512_si256(src[0]); + return Vectorized(_mm512_cvtepu8_epi16(src256)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src128 = _mm512_cvtepi32_epi8(src[0]); + return Vectorized(_mm512_castsi128_si512(src128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + auto src256 = _mm512_cvtepi16_epi8(src[0]); + return Vectorized(_mm512_castsi256_si512(src256)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + src_t, + 1, + typename std::enable_if_t< + (is_reduced_floating_point_v && is_8bit_integer_v) || + (is_reduced_floating_point_v && is_8bit_integer_v), + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN tmp_fp32 = VecConvert::apply(src); + return VecConvert::apply(tmp_fp32); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 2, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + at::vec::Vectorized vec1 = convert_float_to_int8(src[0]); + at::vec::Vectorized vec2 = convert_float_to_int8(src[1]); + __m128 lane2 = _mm512_castps512_ps128(_mm512_castsi512_ps(vec2)); + __m512 result = _mm512_insertf32x4( + _mm512_castsi512_ps(vec1), + lane2, + 1); // Insert lane2 into the second 128-bit lane + return at::vec::Vectorized(_mm512_castps_si512(result)); + } +}; + +template +struct VecConvert< + dst_t, + 1, + float, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_float_to_int8(src[0]); + } +}; + +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + __m512i src2 = + _mm512_castsi128_si512(_mm_castps_si128(_mm512_extractf32x4_ps( + _mm512_castsi512_ps(src[0]), 1) // Extract the second 128-bit lane + )); + return VectorizedN( + convert_int8_to_float(src[0]), + convert_int8_to_float(src2)); + } +}; + +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_to_float(src[0]); + } +}; + +template +struct VecConvert< + dst_t, + 1, + int64_t, + 2, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const VectorizedN& src) { + return VecConvert::apply( + VecConvert::apply(src)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + at::vec::Vectorized src = src_n[0]; + __m128i res128 = cvtfp32_fp8e4m3(src); + return at::vec::Vectorized(_mm512_castsi128_si512(res128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + // cvt first 16x8 bits from Float8_e4m3fn to float + at::vec::Vectorized src = src_n[0]; + __m512 result; + cvtfp8e4m3_fp32(_mm512_castsi512_si128(src), result); + return at::vec::Vectorized(result); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + at::vec::Vectorized src = src_n[0]; + __m128i res128 = cvtfp32_fp8e5m2(src); + return at::vec::Vectorized(_mm512_castsi128_si512(res128)); + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src_n) { + // cvt first 16x8 bits from Float8_e5m2 to float + at::vec::Vectorized src = src_n[0]; + __m512 result; + cvtfp8e5m2_fp32(_mm512_castsi512_si128(src), result); + return at::vec::Vectorized(result); + } +}; + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h new file mode 100644 index 0000000000000000000000000000000000000000..d1ca121d301df6c9fb71b0eef28a9efe8fd03f8b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h @@ -0,0 +1,571 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + // values needs to be public for compilation with clang + // as vec512.h uses it + __m512d values; + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() { + values = _mm512_setzero_pd(); + } + Vectorized(__m512d v) : values(v) {} + Vectorized(double val) { + values = _mm512_set1_pd(val); + } + Vectorized( + double val1, + double val2, + double val3, + double val4, + double val5, + double val6, + double val7, + double val8) { + values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8); + } + operator __m512d() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mask_blend_pd(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask( + _mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized arange( + double base = 0., + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __mmask8 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_pd(mask, ptr); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + __mmask8 mask = (1ULL << count) - 1; + _mm512_mask_storeu_pd(reinterpret_cast(ptr), mask, values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto cmp_mask = + _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + bool has_inf_nan() const { + __m512d self_sub = _mm512_sub_pd(values, values); + return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & + 0x7777777777777777) != 0; + } + Vectorized map(double (*const f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_pd(-0.f); + return _mm512_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm512_castsi512_pd(zero_vector); + const auto nan_vec = _mm512_set1_pd(NAN); + const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ); + const auto not_nan = + _mm512_mask_set1_epi64(zero_vector, not_nan_mask, 0xFFFFFFFFFFFFFFFF); + const auto nan_mask = + _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_pd(c10::pi); + + const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd8_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshd8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind8_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhd8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand8_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhd8_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2d8_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignd8(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd8_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expd8_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2d8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d8_u10(values)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fexp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd8(values, q)); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotd8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d8_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd8_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind8_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhd8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd8_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd8_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized floor() const { + return _mm512_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + return _mm512_xor_pd(_mm512_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterd8(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd8_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_pd( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad8_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm512_div_pd(_mm512_set1_pd(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powd8_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vector, cmp_mask, 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mul_pd(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm512_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized max = _mm512_max_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized min = _mm512_min_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd( + _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm512_min_pd(max, _mm512_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm512_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm512_min_pd(max, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_pd(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmadd_pd(a, b, c); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fnmadd_pd(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmsub_pd(a, b, c); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fnmsub_pd(a, b, c); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h new file mode 100644 index 0000000000000000000000000000000000000000..e390db15bfa62b8607ffa72e8bca018e8e1a9432 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h @@ -0,0 +1,945 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized { + private: + static constexpr __m512i zero_vec{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + __m512 values; + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() { + values = _mm512_setzero_ps(); + } + Vectorized(__m512 v) : values(v) {} + Vectorized(float val) { + values = _mm512_set1_ps(val); + } + Vectorized( + float val1, + float val2, + float val3, + float val4, + float val5, + float val6, + float val7, + float val8, + float val9, + float val10, + float val11, + float val12, + float val13, + float val14, + float val15, + float val16) { + values = _mm512_setr_ps( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + Vectorized(const float (&arr)[16]) + : Vectorized( + arr[0], + arr[1], + arr[2], + arr[3], + arr[4], + arr[5], + arr[6], + arr[7], + arr[8], + arr[9], + arr[10], + arr[11], + arr[12], + arr[13], + arr[14], + arr[15]) {} + operator __m512() const { + return values; + } + template + static Vectorized blend( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mask_blend_ps(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized arange( + float base = 0.f, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __mmask16 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_ps(mask, ptr); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + __mmask16 mask = (1ULL << count) - 1; + _mm512_mask_storeu_ps(reinterpret_cast(ptr), mask, values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + __mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + bool has_inf_nan() const { + __m512 self_sub = _mm512_sub_ps(values, values); + return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & + 0x7777777777777777) != 0; + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_ps(-0.f); + return _mm512_andnot_ps(mask, values); + } + Vectorized angle() const { + __m512 zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto not_nan_vec = _mm512_mask_set1_epi32( + _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask( + _mm512_castsi512_ps(not_nan_vec), zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf16_u10(values)); + } + Vectorized acosh() const { + return Vectorized(Sleef_acoshf16_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf16_u10(values)); + } + Vectorized asinh() const { + return Vectorized(Sleef_asinhf16_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf16_u10(values)); + } + Vectorized atanh() const { + return Vectorized(Sleef_atanhf16_u10(values)); + } + Vectorized atan2(const Vectorized& b) const { + return Vectorized(Sleef_atan2f16_u10(values, b)); + } + Vectorized copysign(const Vectorized& sign) const { + return Vectorized(Sleef_copysignf16(values, sign)); + } + Vectorized erf() const { + // constants + const auto neg_zero_vec = _mm512_set1_ps(-0.f); + const auto one_vec = _mm512_set1_ps(1.0f); + const auto p = _mm512_set1_ps(0.3275911f); + const auto p1 = _mm512_set1_ps(0.254829592f); + const auto p2 = _mm512_set1_ps(-0.284496736f); + const auto p3 = _mm512_set1_ps(1.421413741f); + const auto p4 = _mm512_set1_ps(-1.453152027f); + const auto p5 = _mm512_set1_ps(1.061405429f); + // sign(x) + auto sign_mask = _mm512_and_ps(neg_zero_vec, values); + auto abs_vec = _mm512_abs_ps(values); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = _mm512_fmadd_ps(p, abs_vec, one_vec); + auto t = _mm512_div_ps(one_vec, tmp0); + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = _mm512_fmadd_ps(p5, t, p4); + auto tmp2 = _mm512_fmadd_ps(tmp1, t, p3); + auto tmp3 = _mm512_fmadd_ps(tmp2, t, p2); + auto r = _mm512_fmadd_ps(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = _mm512_mul_ps(values, values); + auto neg_pow_2 = _mm512_xor_ps(neg_zero_vec, pow_2); + // auto tmp4 = exp(neg_pow_2); + auto tmp4 = Vectorized(Sleef_expf16_u10(neg_pow_2)); + auto tmp5 = _mm512_xor_ps(neg_zero_vec, tmp4); + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = _mm512_mul_ps(tmp5, t); + auto tmp7 = _mm512_fmadd_ps(tmp6, r, one_vec); + return _mm512_xor_ps(sign_mask, tmp7); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf16_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf16_u10(values)); + } + Vectorized exp2() const { + return Vectorized(Sleef_exp2f16_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f16_u10(values)); + } + Vectorized fexp_u20() const { + const __m512 vec_c0 = _mm512_set1_ps(0.00010703434948458272f); + const __m512 vec_c1 = _mm512_set1_ps(0.30354260500649682f); + const __m512 vec_c2 = _mm512_set1_ps(-0.22433836478672356); + const __m512 vec_c3 = _mm512_set1_ps(-0.079204240219773236); + + const __m512 vec_exp_log2ef = + _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + + const __m512 vec_a = _mm512_set1_ps(std::pow(2, 23) / std::log2(2)); + const __m512 vec_b = _mm512_set1_ps(std::pow(2, 23) * 127.f); + + const __m512 vec_ln_flt_min = + _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); + const __m512 vec_ln_flt_max = + _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); + __m512i vec_infinity = _mm512_set1_epi32(0x7F800000); + __m512i vec_zero = _mm512_setzero_epi32(); + + // Fast Exponential Computation on SIMD Architectures + // A. Cristiano I. Malossi, Yves Ineichen, Costas Bekas, and Alessandro + // Curioni exp(x) = 2**(x * log2(e)) + // = 2**xi * 2**xf - TIPS we are using the EEEE floating point + // representation with identification to the exponent and the + // mentissa + // 2**xf will be approximated to a polynomial of degree 3 computed with + // Horner method + // mask for the boundary condition + auto min_mask = _mm512_cmp_ps_mask(values, vec_ln_flt_min, _CMP_LT_OS); + auto max_mask = _mm512_cmp_ps_mask(values, vec_ln_flt_max, _CMP_GT_OS); + + // transformation with log2(e) + auto vec_src = _mm512_mul_ps(values, vec_exp_log2ef); + auto vec_fractional = _mm512_sub_ps(vec_src, _mm512_floor_ps(vec_src)); + + // compute polynomial using Horner Scheme, for superscalar processor + auto vec_res = _mm512_fmadd_ps(vec_fractional, vec_c3, vec_c2); + vec_res = _mm512_fmadd_ps(vec_fractional, vec_res, vec_c1); + vec_res = _mm512_fmadd_ps(vec_fractional, vec_res, vec_c0); + + vec_src = _mm512_sub_ps(vec_src, vec_res); + // the tips is here, headache in perspective + auto tmp = _mm512_fmadd_ps(vec_a, vec_src, vec_b); + // headache bis - we loose precision with the cast but it "fits", but ok + // after f32 -> f16 later + __m512i casted_integer = _mm512_cvttps_epi32(tmp); + // boundary condition, lower than the min -> 0 + casted_integer = _mm512_mask_mov_epi32(casted_integer, min_mask, vec_zero); + // boundary condition, larger than the max -> +oo + casted_integer = + _mm512_mask_mov_epi32(casted_integer, max_mask, vec_infinity); + // final interpretation to float + return _mm512_castsi512_ps(casted_integer); + } + Vectorized exp_u20() const { + // A faster version of exp with ULP=20 + const __m512 vec_factorial_1 = + _mm512_set1_ps(0.999999701f); // 1/factorial(1) + const __m512 vec_factorial_2 = + _mm512_set1_ps(0.499991506f); // 1/factorial(2) + const __m512 vec_factorial_3 = + _mm512_set1_ps(0.166676521f); // 1/factorial(3) + const __m512 vec_factorial_4 = + _mm512_set1_ps(0.0418978221f); // 1/factorial(4) + const __m512 vec_factorial_5 = + _mm512_set1_ps(0.00828929059f); // 1/factorial(5) + const __m512 vec_exp_log2ef = + _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + const __m512 vec_half = _mm512_set1_ps(0.5f); + const __m512 vec_one = _mm512_set1_ps(1.f); + const __m512 vec_zero = _mm512_set1_ps(0.f); + const __m512 vec_two = _mm512_set1_ps(2.f); + const __m512 vec_ln2f = + _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + const __m512 vec_ln_flt_min = + _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); + const __m512 vec_ln_flt_max = + _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); + const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; + + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + auto less_ln_flt_min_mask = + _mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); + auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); + vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); + + // fx = floorf(x * log2ef + 0.5) + auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); + auto vec_fx_i = _mm512_cvt_roundps_epi32( + vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); + vec_fx = _mm512_cvtepi32_ps(vec_fx_i); + + // x = x - fx * ln2 + auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); + + // compute polynomial + auto vec_res = + _mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); + vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); + + // compute 2^(n-1) + auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); + auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); + auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); + vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); + auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); + vec_two_pow_n = + _mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero); + + // y = y * 2^n + vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); + vec_res = _mm512_mul_ps(vec_res, vec_two); + return vec_res; + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf16(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf16_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f16_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f16_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pf16_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf16_u35(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhf16_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosf16_u35(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf16_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized floor() const { + return _mm512_floor_ps(values); + } + Vectorized hypot(const Vectorized& b) const { + return Vectorized(Sleef_hypotf16_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized& x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (const auto i : c10::irange(size())) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm512_xor_ps(_mm512_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized& b) const { + return Vectorized(Sleef_nextafterf16(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf16_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf16_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_ps( + values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammaf16_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm512_div_ps(_mm512_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values)); + } + Vectorized pow(const Vectorized& b) const { + return Vectorized(Sleef_powf16_u10(values, b)); + } + float reduce_add() const { + return _mm512_reduce_add_ps(values); + } + float reduce_max() const { + return _mm512_reduce_max_ps(values); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, mask, 0xFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mul_ps(a, b); +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return _mm512_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto max = _mm512_max_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto min = _mm512_min_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, isnan_mask, 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min, + const Vectorized& max) { + return _mm512_min_ps(max, _mm512_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max) { + return _mm512_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min) { + return _mm512_max_ps(min, a); +} + +template <> +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized inline operator|( + const Vectorized& a, + const Vectorized& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized inline operator^( + const Vectorized& a, + const Vectorized& b) { + return _mm512_xor_ps(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmadd_ps(a, b, c); +} + +template <> +Vectorized inline fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fnmadd_ps(a, b, c); +} + +template <> +Vectorized inline fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fmsub_ps(a, b, c); +} + +template <> +Vectorized inline fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return _mm512_fnmsub_ps(a, b, c); +} + +// TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen for micro gemm +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 +// kernel for transposing mxn where m, n <= 16 +// (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + N instructions +inline void transpose_block( + at::vec::VectorizedN& input, + int M = 16, + int N = 16) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_block expects M, N <= 16."); + // unpacking and interleaving 32-bit elements + __m512 temp[16]; + int i; + for (i = 0; i < (M + 1) / 2; ++i) { + temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]); + temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]); + } + for (i = i * 2; i < 16; ++i) { + temp[i] = _mm512_setzero_ps(); + } + + // unpacking and interleaving 64-bit elements + for (i = 0; i < (M + 3) / 4; ++i) { + input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd( + _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); + input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd( + _mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2]))); + input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd( + _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); + input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd( + _mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3]))); + } + + // shuffle 128-bits (composed of 4 32-bit elements) + for (i = 0; i < (M + 7) / 8; ++i) { + temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88); + temp[8 * i + 1] = + _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88); + temp[8 * i + 2] = + _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88); + temp[8 * i + 3] = + _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88); + temp[8 * i + 4] = + _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd); + temp[8 * i + 5] = + _mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd); + temp[8 * i + 6] = + _mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd); + temp[8 * i + 7] = + _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd); + } + + for (i = 0; i < N; ++i) { + if (i < 8) { + input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88); + } else { + input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd); + } + } +} + +// TODO(jgong5): rewrite with ATEN vectorized (need to add unpack and shuffle) +// Used by Inductor CPP codegen +// Code referred to FBGEMM: +// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304 +// kernel for transposing mxn where m, n <= 16 +// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions +inline void transpose_mxn_16x16( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst, + int M, + int N) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn expects M, N <= 16."); + // load from src to registers + at::vec::VectorizedN input; + int i; + if (N == 16) { + for (i = 0; i < M; ++i) { + input[i] = _mm512_loadu_ps(&src[i * ld_src]); + } + } else { + __mmask16 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); + } + } + for (; i < 16; ++i) { + // Not really needed but to avoid uninitialized variable warning. + // Shouldn't be much overhead because xor can be executed in parallel with + // other instructions. + input[i] = _mm512_setzero_ps(); + } + + transpose_block(input, M, N); + + // store from registers to dst + if (M == 16) { + for (i = 0; i < N; ++i) { + _mm512_storeu_ps(&dst[i * ld_dst], input[i]); + } + } else { + __mmask16 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); + } + } +} + +template <> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst, + int M, + int N) { + int64_t i = 0; + for (; i < M / 16 * 16; i += 16) { + int64_t j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16); + } + // handle remainder j + int nrem = N - j; + if (nrem > 0) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem); + } + } + // handle remainder i + int mrem = M - i; + if (mrem > 0) { + int j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16); + } + // handle remainder j + int nrem = N - j; + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem); + } +} + +template < + typename T, + int M, + int N, + typename std::enable_if_t, int> = 0> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h new file mode 100644 index 0000000000000000000000000000000000000000..b0aa8e3a05cd29529145415da9ba08f356e24d7e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float8.h @@ -0,0 +1,666 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) +#define SLEEF_STATIC_LIBS +#include +#endif + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +static inline void cvtfp8e4m3_fp32(const __m128i& a, __m512& o) { + // Zero Extend + __m512i x = _mm512_cvtepu8_epi32(a); + __m512i val = _mm512_and_epi32( + _mm512_slli_epi32(x, 24), _mm512_set1_epi32(0x7FFFFFFF)); // nonsign_val + __m512i mant = + _mm512_and_si512(x, _mm512_set1_epi32(0x07)); // mantissa = x & 0x07 + __m512i exp = _mm512_and_si512( + _mm512_srli_epi32(x, 3), + _mm512_set1_epi32(0x0F)); // exp = (x >> 3) & 0x0F + __m512i sign = + _mm512_and_si512(x, _mm512_set1_epi32(0x80)); // sign = x & 0x80 + __m512i _zeros = _mm512_setzero_si512(); + + // --- Step 1: Calculate the renorm_shift + __m512i renorm_shift = _zeros; + // Denorm case (exp == 0 && mant != 0) --- + __mmask16 denormal_mask = _mm512_cmpeq_epi32_mask(exp, _zeros) & + _mm512_cmpneq_epi32_mask(mant, _zeros); + if (denormal_mask) { + // An alternative solution is as what scalar did in + // pytorch/c10/util/Float8_e4m3fn.h To count the num of leading zeros, since + // here we know the unsigned denorm value has zero sign and exp which is 5 + // leading zeros, we need to count the leading zero of mant (3bit) which may + // done through table lookup for example: const uint8_t lz_table[8] = {3, 2, + // 1, 1, 0, 0, 0, 0}; num_leading_zero = lz_table[mant] + 5; + + __m512i _ones = _mm512_set1_epi32(1); + __m512i _twos = _mm512_set1_epi32(2); + __m512i _threes = _mm512_set1_epi32(3); + + // Default leading zero number for denorm value is 1 = 5 - 4 + __m512i denorm_renorm_shift = _ones; + // For mant 001, leading zero number is 3 = 7 -4 + __mmask16 leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _ones); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _threes); + // For mant 010 and 011, leading zero number is 2 = 6 -4 + leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _twos); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _twos); + leading_Zero_mask = _mm512_cmpeq_epi32_mask(mant, _threes); + denorm_renorm_shift = + _mm512_mask_mov_epi32(denorm_renorm_shift, leading_Zero_mask, _twos); + + renorm_shift = + _mm512_mask_mov_epi32(renorm_shift, denormal_mask, denorm_renorm_shift); + } + + // --- Step 2: calculate norm and denorm --- + __m512i norm_shifted = + _mm512_srli_epi32(_mm512_sllv_epi32(val, renorm_shift), 4); + // exponent bias adjustment: (0x78 - renorm_shift) << 23 + __m512i exp_bias = _mm512_slli_epi32( + _mm512_sub_epi32(_mm512_set1_epi32(0x78), renorm_shift), 23); + val = _mm512_add_epi32(norm_shifted, exp_bias); + + // --- Step 3: Nan case (exp == 0xF && mant == 0x07) --- + __mmask16 nan_mask = _mm512_cmpeq_epi32_mask(exp, _mm512_set1_epi32(0xF)) & + _mm512_cmpeq_epi32_mask(mant, _mm512_set1_epi32(0x07)); + if (nan_mask) { + const __m512i nan_values = _mm512_set1_epi32(0x7FC00000); + val = _mm512_mask_mov_epi32(val, nan_mask, nan_values); + } + + // --- Step 4: Zero case (exp == 0x00 && mant == 0x00) --- + __mmask16 zero_mask = _mm512_cmpeq_epi32_mask(exp, _zeros) & + _mm512_cmpeq_epi32_mask(mant, _zeros); + if (zero_mask) { + val = _mm512_mask_mov_epi32(val, zero_mask, _zeros); + } + + // --- Step 5: OR with sign (sign bit << 24 to get to bit 31) --- + val = _mm512_or_si512(val, _mm512_slli_epi32(sign, 24)); + + o = _mm512_castsi512_ps(val); +} + +static inline __m128i cvtfp32_fp8e4m3(const __m512& src) { + // cvt 16x32 from fp32 to fp8 e4m3 + const __m512i sign_mask = _mm512_set1_epi32(0x80000000); + const __m512i fp8_max = _mm512_set1_epi32(UINT32_C(1087) << 20); + const __m512i denorm_thresh = _mm512_set1_epi32(UINT32_C(121) << 23); + const __m512i denorm_mask = _mm512_set1_epi32(UINT32_C(141) << 23); + const __m512i bias_part1 = _mm512_set1_epi32((uint32_t)(7 - 127) << 23); + const __m512i rounding_bias = _mm512_set1_epi32(0x7FFFF); + __m512i f_bits = _mm512_castps_si512(src); + // Extract and save sign + __m512i sign = _mm512_and_epi32(f_bits, sign_mask); + f_bits = _mm512_xor_epi32(f_bits, sign); + + // Prepare result containers + __m512i result = _mm512_setzero_si512(); + + // Step 1: Handle case of overflow + // (f_bits >= fp8_max): set result = 0x7f + __mmask16 overflow_mask = _mm512_cmpge_epu32_mask(f_bits, fp8_max); + if (overflow_mask) { + result = _mm512_mask_set1_epi32(result, overflow_mask, 0x7f); + } + + // Step 2: Handle small numbers (denormals) + // Small numbers (f_bits < denorm_thresh) + __mmask16 denorm_thresh_mask = _mm512_cmplt_epu32_mask(f_bits, denorm_thresh); + + if (denorm_thresh_mask) { + __m512 small_input = _mm512_castsi512_ps(f_bits); + __m512 small_denorm = + _mm512_add_ps(small_input, _mm512_castsi512_ps(denorm_mask)); + __m512i small_denorm_bits = _mm512_castps_si512(small_denorm); + __m512i small_result = _mm512_sub_epi32(small_denorm_bits, denorm_mask); + result = _mm512_mask_mov_epi32(result, denorm_thresh_mask, small_result); + } + + // Step 3: Handle normal numbers + __mmask16 normal_mask = ~(overflow_mask | denorm_thresh_mask); + + if (normal_mask) { + // mant_odd = (f_bits >> 20) & 1 + __m512i mant_odd = + _mm512_and_epi32(_mm512_srli_epi32(f_bits, 20), _mm512_set1_epi32(1)); + // f_bits += bias_part1 + rounding_bias + __m512i rounded = _mm512_add_epi32(f_bits, bias_part1); + rounded = _mm512_add_epi32(rounded, rounding_bias); + // Add mant_odd + rounded = _mm512_add_epi32(rounded, mant_odd); + // Shift right by 20 bits + __m512i normal_result = _mm512_srli_epi32(rounded, 20); + result = _mm512_mask_mov_epi32(result, normal_mask, normal_result); + } + + // Merge back the sign + __m512i sign_shifted = _mm512_srli_epi32(sign, 24); + result = _mm512_or_epi32(result, sign_shifted); + + // Now result is 16 x 32-bit integers, but we only need 8-bit for each + __m512i packed = _mm512_and_si512(result, _mm512_set1_epi32(0xFF)); + + // Narrow 32-bit integers to 8-bit + return _mm512_cvtepi32_epi8(packed); +} + +static inline float fp8e4m3_to_fp32_scalar(uint8_t val) { + __m512i v = _mm512_set1_epi8(val); + __m128i v_128 = _mm512_castsi512_si128(v); + __m512 o; + cvtfp8e4m3_fp32(v_128, o); + return _mm512_cvtss_f32(o); +} + +static inline uint8_t fp32_to_fp8e4m3_scalar(float val) { + __m512 v = _mm512_set1_ps(val); + __m128i o = cvtfp32_fp8e4m3(v); + return static_cast(_mm_cvtsi128_si32(o)); +} + +static inline void cvtfp8e5m2_fp32(const __m128i& a, __m512& o) { + __m256i a_256 = _mm256_castsi128_si256(a); + __m512i a_512 = _mm512_cvtepu8_epi16(a_256); + a_512 = _mm512_slli_epi16(a_512, 8); + a_256 = _mm512_castsi512_si256(a_512); + cvtfp16_fp32(a_256, o); +} + +static inline __m128i cvtfp32_fp8e5m2(const __m512& src) { + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + // Cvt to bits + __m512i input_bits = _mm512_castps_si512(src); + __m512i result = _mm512_setzero_si512(); + + // Get the sign + __m512i sign = _mm512_and_si512(input_bits, _mm512_set1_epi32(0x80000000)); + + // Get the unsigned input + input_bits = _mm512_xor_si512(input_bits, sign); + + // Calculate the mask for inf, nan and denorm + __mmask16 greater_than_fp8_max = + _mm512_cmpge_epi32_mask(input_bits, _mm512_set1_epi32(fp8_max)); + __mmask16 greater_than_fp32_inf = + _mm512_cmpgt_epi32_mask(input_bits, _mm512_set1_epi32(fp32_inf)); + __mmask16 less_than_normal = _mm512_cmpgt_epi32_mask( + _mm512_set1_epi32((UINT32_C(113) << 23)), input_bits); + __m512i temp_bits_for_denorm = _mm512_setzero_si512(); + if (less_than_normal) { + __m512i denorm_mask_512i = _mm512_set1_epi32(denorm_mask); + temp_bits_for_denorm = _mm512_castps_si512(_mm512_add_ps( + _mm512_castsi512_ps(input_bits), + _mm512_castsi512_ps(denorm_mask_512i))); + temp_bits_for_denorm = + _mm512_sub_epi32(temp_bits_for_denorm, denorm_mask_512i); + } + + // Step 1: Norm Val + __m512i mant_odd_mask = + _mm512_and_epi32(_mm512_srli_epi32(input_bits, 21), _mm512_set1_epi32(1)); + input_bits = _mm512_add_epi32( + input_bits, _mm512_set1_epi32(((uint32_t)(15 - 127) << 23) + 0xFFFFF)); + input_bits = _mm512_add_epi32(input_bits, mant_odd_mask); + result = _mm512_srli_epi32(input_bits, 21); + + // Step 2: INF and NAN + if (greater_than_fp8_max) { + result = _mm512_mask_mov_epi32( + result, greater_than_fp8_max, _mm512_set1_epi8(0x7C)); + if (greater_than_fp32_inf) { + result = _mm512_mask_mov_epi32( + result, greater_than_fp32_inf, _mm512_set1_epi8(0x7F)); + } + } + + // Step 3: Denorm val + if (less_than_normal) { + result = + _mm512_mask_mov_epi32(result, less_than_normal, temp_bits_for_denorm); + } + + // Step 4: restore sign + result = _mm512_or_si512(result, _mm512_srli_epi32(sign, 24)); + + return _mm512_cvtepi32_epi8(result); +} + +static inline float fp8e5m2_to_fp32_scalar(uint8_t val) { + __m512i v = _mm512_set1_epi8(val); + __m128i v_128 = _mm512_castsi512_si128(v); + __m512 o; + cvtfp8e5m2_fp32(v_128, o); + return _mm512_cvtss_f32(o); +} + +static inline uint8_t fp32_to_fp8e5m2_scalar(float val) { + __m512 v = _mm512_set1_ps(val); + __m128i o = cvtfp32_fp8e5m2(v); + return static_cast(_mm_cvtsi128_si32(o)); +} + +template +class Vectorizedf8 { + static_assert( + std::integral_constant < bool, + std::is_same_v || std::is_same_v < T, + at::Float8_e5m2 >> ::value, + "Support only float8 e4m3."); + + private: + __m512i values; + template + Vectorized inline binary_compare(const VectorizedType& b, Op op) const { + __m512 a0, a1, a2, a3; + __m512 b0, b1, b2, b3; + __m512 o0, o1, o2, o3; + if constexpr (std::is_same_v) { + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 0), a0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 1), a1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 2), a2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 3), a3); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3); + } else { + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 0), a0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 1), a1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 2), a2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 3), a3); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3); + } + + o0 = op(a0, b0); + o1 = op(a1, b1); + o2 = op(a2, b2); + o3 = op(a3, b3); + __m128i o128_0, o128_1, o128_2, o128_3; + if constexpr (std::is_same_v) { + o128_0 = cvtfp32_fp8e4m3(o0); + o128_1 = cvtfp32_fp8e4m3(o1); + o128_2 = cvtfp32_fp8e4m3(o2); + o128_3 = cvtfp32_fp8e4m3(o3); + } else { + o128_0 = cvtfp32_fp8e5m2(o0); + o128_1 = cvtfp32_fp8e5m2(o1); + o128_2 = cvtfp32_fp8e5m2(o2); + o128_3 = cvtfp32_fp8e5m2(o3); + } + + __m512i result = _mm512_setzero_si512(); + result = _mm512_inserti32x4(result, o128_0, 0); + result = _mm512_inserti32x4(result, o128_1, 1); + result = _mm512_inserti32x4(result, o128_2, 2); + result = _mm512_inserti32x4(result, o128_3, 3); + + return result; + } + + public: + using value_type = uint8_t; + using size_type = int; + static constexpr size_type size() { + return 64; + } + Vectorizedf8() {} + Vectorizedf8(__m512i v) : values(v) {} + Vectorizedf8(T val) { + value_type uw = val.x; + values = _mm512_set1_epi8(uw); + } + operator __m512i() const { + return values; + } + T& operator[](int idx) = delete; + const T& operator[](int idx) const = delete; + static Vectorized loadu(const void* ptr, int16_t count = size()) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else if (count == 16) { + // Fast path if only load element number of 16 + __m128i input_128 = + _mm_loadu_si128(reinterpret_cast(ptr)); + return _mm512_castsi128_si512(input_128); + } else { + __mmask64 mask = (1ULL << count) - 1; + return _mm512_maskz_loadu_epi8(mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + if (count == 16) { + // Fast path if only store element number of 16 + _mm_storeu_si128( + reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values)); + } else { + __mmask64 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi8(ptr, mask, values); + } + } + } + + Vectorized abs() const { + return _mm512_andnot_si512(_mm512_set1_epi8(0x80), values); + } + + Vectorized inline operator==(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator!=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator>(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator>=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator<(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } + + Vectorized inline operator<=(const Vectorizedf8& other) const { + return binary_compare(other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); + } +}; + +template <> +class Vectorized : public Vectorizedf8 { + public: + using Vectorizedf8::Vectorizedf8; + + using value_type = Float8_e4m3fn; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template < + typename T, + typename Op, + std::enable_if_t< + std::is_same_v || + std::is_same_v, + int> = 0> +static inline Vectorized binary_fp8_op_as_fp32( + const Vectorized& a, + const Vectorized& b, + Op op) { + __m512 a0, a1, a2, a3; + __m512 b0, b1, b2, b3; + __m512 o0, o1, o2, o3; + if constexpr (std::is_same_v) { + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 0), a0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 0), b0); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 1), a1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 1), b1); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 2), a2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 2), b2); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(a, 3), a3); + cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b, 3), b3); + } else { + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 0), a0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 0), b0); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 1), a1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 1), b1); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 2), a2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 2), b2); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(a, 3), a3); + cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b, 3), b3); + } + o0 = op(a0, b0); + o1 = op(a1, b1); + o2 = op(a2, b2); + o3 = op(a3, b3); + + __m128i o128_0, o128_1, o128_2, o128_3; + if constexpr (std::is_same_v) { + o128_0 = cvtfp32_fp8e4m3(o0); + o128_1 = cvtfp32_fp8e4m3(o1); + o128_2 = cvtfp32_fp8e4m3(o2); + o128_3 = cvtfp32_fp8e4m3(o3); + } else { + o128_0 = cvtfp32_fp8e5m2(o0); + o128_1 = cvtfp32_fp8e5m2(o1); + o128_2 = cvtfp32_fp8e5m2(o2); + o128_3 = cvtfp32_fp8e5m2(o3); + } + + __m512i result = _mm512_setzero_si512(); + result = _mm512_inserti32x4(result, o128_0, 0); + result = _mm512_inserti32x4(result, o128_1, 1); + result = _mm512_inserti32x4(result, o128_2, 2); + result = _mm512_inserti32x4(result, o128_3, 3); + + return result; +} + +// Refer to +// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, +// -, *, /, planned to be deleted in the future and here is just to make +// compiler happy +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} + +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} + +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} + +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +class Vectorized : public Vectorizedf8 { + public: + using Vectorizedf8::Vectorizedf8; + + using value_type = Float8_e5m2; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +// Refer to +// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +, +// -, *, /, planned to be deleted in the future and here is just to make +// compiler happy +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_add_ps(x, y); + }); +} + +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_sub_ps(x, y); + }); +} + +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_mul_ps(x, y); + }); +} + +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return binary_fp8_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { + return _mm512_div_ps(x, y); + }); +} + +Vectorized inline operator&( + const Vectorized& a, + const Vectorized& b) { + return _mm512_and_si512(a, b); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h new file mode 100644 index 0000000000000000000000000000000000000000..2044a199105a3dfe76e9fda09acc68251510651b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h @@ -0,0 +1,2126 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#ifdef CPU_CAPABILITY_AVX512 + +struct Vectorizedi { + protected: + __m512i values; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static inline __m512i invert(const __m512i& v) { + const auto ones = _mm512_set1_epi64(-1); + return _mm512_xor_si512(ones, v); + } + + public: + Vectorizedi() {} + Vectorizedi(__m512i v) : values(v) {} + operator __m512i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX512 + +#ifdef CPU_CAPABILITY_AVX512 + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + + public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() { + values = _mm512_setzero_si512(); + } + Vectorized(int64_t v) { + values = _mm512_set1_epi64(v); + } + Vectorized( + int64_t val1, + int64_t val2, + int64_t val3, + int64_t val4, + int64_t val5, + int64_t val6, + int64_t val7, + int64_t val8) { + values = _mm512_setr_epi64(val1, val2, val3, val4, val5, val6, val7, val8); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi64(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mask_ = _mm512_cmp_epi64_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi64(mask_, a.values, b.values); + } + template + static Vectorized arange( + int64_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask8 mask = (1ULL << count) - 1; + auto ones = _mm512_set1_epi64(1); + return _mm512_mask_loadu_epi64(ones, mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask8 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi64(ptr, mask, values); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto is_larger_mask = _mm512_cmpgt_epi64_mask(zero_vector, values); + auto is_larger = + _mm512_mask_set1_epi64(zero_vector, is_larger_mask, 0xFFFFFFFFFFFFFFFF); + auto inverse = _mm512_xor_si512(values, is_larger); + return _mm512_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi64(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; +template <> +class Vectorized : public Vectorizedi { + private: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; + + public: + using value_type = int32_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { + values = _mm512_set1_epi32(v); + } + Vectorized( + int32_t val1, + int32_t val2, + int32_t val3, + int32_t val4, + int32_t val5, + int32_t val6, + int32_t val7, + int32_t val8, + int32_t val9, + int32_t val10, + int32_t val11, + int32_t val12, + int32_t val13, + int32_t val14, + int32_t val15, + int32_t val16) { + values = _mm512_setr_epi32( + val1, + val2, + val3, + val4, + val5, + val6, + val7, + val8, + val9, + val10, + val11, + val12, + val13, + val14, + val15, + val16); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi32(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi32(0xFFFFFFFF); + auto mask_ = _mm512_cmp_epi32_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi32(mask_, a.values, b.values); + } + template + static Vectorized arange( + int32_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask16 mask = (1ULL << count) - 1; + auto ones = _mm512_set1_epi32(1); + return _mm512_mask_loadu_epi32(ones, mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask16 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi32(ptr, mask, values); + } + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + int32_t reduce_add() const { + return _mm512_reduce_add_epi32(values); + } + int32_t reduce_max() const { + return _mm512_reduce_max_epi32(values); + } + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_vec = + _mm512_loadu_si512(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_ps(input_vec); + _mm512_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, double* dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); + i += Vectorized::size()) { + auto input_256_vec = + _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_pd(input_256_vec); + _mm512_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorizedi { + private: + static const Vectorized ones; + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + + public: + using value_type = int16_t; + static constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { + values = _mm512_set1_epi16(v); + } + Vectorized( + int16_t val1, + int16_t val2, + int16_t val3, + int16_t val4, + int16_t val5, + int16_t val6, + int16_t val7, + int16_t val8, + int16_t val9, + int16_t val10, + int16_t val11, + int16_t val12, + int16_t val13, + int16_t val14, + int16_t val15, + int16_t val16, + int16_t val17, + int16_t val18, + int16_t val19, + int16_t val20, + int16_t val21, + int16_t val22, + int16_t val23, + int16_t val24, + int16_t val25, + int16_t val26, + int16_t val27, + int16_t val28, + int16_t val29, + int16_t val30, + int16_t val31, + int16_t val32) { + values = _mm512_set_epi16( + val32, + val31, + val30, + val29, + val28, + val27, + val26, + val25, + val24, + val23, + val22, + val21, + val20, + val19, + val18, + val17, + val16, + val15, + val14, + val13, + val12, + val11, + val10, + val9, + val8, + val7, + val6, + val5, + val4, + val3, + val2, + val1); + } + template + static Vectorized blend( + Vectorized a, + Vectorized b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange( + int16_t base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step); + } + static Vectorized set( + Vectorized a, + Vectorized b, + int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else { + __mmask32 mask = (1ULL << count) - 1; + auto ones = _mm512_set1_epi16(1); + return _mm512_mask_loadu_epi16(ones, mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi16(ptr, mask, values); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + + protected: + static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; + + public: + using value_type = T; + static constexpr int size() { + return 64; + } + using Vectorizedi::Vectorizedi; + Vectorized8() {} + Vectorized8(T v) { + values = _mm512_set1_epi8(v); + } + Vectorized8( + T val1, + T val2, + T val3, + T val4, + T val5, + T val6, + T val7, + T val8, + T val9, + T val10, + T val11, + T val12, + T val13, + T val14, + T val15, + T val16, + T val17, + T val18, + T val19, + T val20, + T val21, + T val22, + T val23, + T val24, + T val25, + T val26, + T val27, + T val28, + T val29, + T val30, + T val31, + T val32, + T val33, + T val34, + T val35, + T val36, + T val37, + T val38, + T val39, + T val40, + T val41, + T val42, + T val43, + T val44, + T val45, + T val46, + T val47, + T val48, + T val49, + T val50, + T val51, + T val52, + T val53, + T val54, + T val55, + T val56, + T val57, + T val58, + T val59, + T val60, + T val61, + T val62, + T val63, + T val64) { + values = _mm512_set_epi8( + val64, + val63, + val62, + val61, + val60, + val59, + val58, + val57, + val56, + val55, + val54, + val53, + val52, + val51, + val50, + val49, + val48, + val47, + val46, + val45, + val44, + val43, + val42, + val41, + val40, + val39, + val38, + val37, + val36, + val35, + val34, + val33, + val32, + val31, + val30, + val29, + val28, + val27, + val26, + val25, + val24, + val23, + val22, + val21, + val20, + val19, + val18, + val17, + val16, + val15, + val14, + val13, + val12, + val11, + val10, + val9, + val8, + val7, + val6, + val5, + val4, + val3, + val2, + val1); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi8(mask, a.values, b.values); + } + template + static Vectorized arange( + T base = 0, + step_t step = static_cast(1)) { + return Vectorized( + base, + base + step, + base + 2 * step, + base + 3 * step, + base + 4 * step, + base + 5 * step, + base + 6 * step, + base + 7 * step, + base + 8 * step, + base + 9 * step, + base + 10 * step, + base + 11 * step, + base + 12 * step, + base + 13 * step, + base + 14 * step, + base + 15 * step, + base + 16 * step, + base + 17 * step, + base + 18 * step, + base + 19 * step, + base + 20 * step, + base + 21 * step, + base + 22 * step, + base + 23 * step, + base + 24 * step, + base + 25 * step, + base + 26 * step, + base + 27 * step, + base + 28 * step, + base + 29 * step, + base + 30 * step, + base + 31 * step, + base + 32 * step, + base + 33 * step, + base + 34 * step, + base + 35 * step, + base + 36 * step, + base + 37 * step, + base + 38 * step, + base + 39 * step, + base + 40 * step, + base + 41 * step, + base + 42 * step, + base + 43 * step, + base + 44 * step, + base + 45 * step, + base + 46 * step, + base + 47 * step, + base + 48 * step, + base + 49 * step, + base + 50 * step, + base + 51 * step, + base + 52 * step, + base + 53 * step, + base + 54 * step, + base + 55 * step, + base + 56 * step, + base + 57 * step, + base + 58 * step, + base + 59 * step, + base + 60 * step, + base + 61 * step, + base + 62 * step, + base + 63 * step); + } + static Vectorized set(Vectorized a, Vectorized b, T count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + case 32: + return blend<0xFFFFFFFF>(a, b); + case 33: + return blend<0x1FFFFFFFF>(a, b); + case 34: + return blend<0x3FFFFFFFF>(a, b); + case 35: + return blend<0x7FFFFFFFF>(a, b); + case 36: + return blend<0xFFFFFFFFF>(a, b); + case 37: + return blend<0x1FFFFFFFFF>(a, b); + case 38: + return blend<0x3FFFFFFFFF>(a, b); + case 39: + return blend<0x7FFFFFFFFF>(a, b); + case 40: + return blend<0xFFFFFFFFFF>(a, b); + case 41: + return blend<0x1FFFFFFFFFF>(a, b); + case 42: + return blend<0x3FFFFFFFFFF>(a, b); + case 43: + return blend<0x7FFFFFFFFFF>(a, b); + case 44: + return blend<0xFFFFFFFFFFF>(a, b); + case 45: + return blend<0x1FFFFFFFFFFF>(a, b); + case 46: + return blend<0x3FFFFFFFFFFF>(a, b); + case 47: + return blend<0x7FFFFFFFFFFF>(a, b); + case 48: + return blend<0xFFFFFFFFFFFF>(a, b); + case 49: + return blend<0x1FFFFFFFFFFFF>(a, b); + case 50: + return blend<0x3FFFFFFFFFFFF>(a, b); + case 51: + return blend<0x7FFFFFFFFFFFF>(a, b); + case 52: + return blend<0xFFFFFFFFFFFFF>(a, b); + case 53: + return blend<0x1FFFFFFFFFFFFF>(a, b); + case 54: + return blend<0x3FFFFFFFFFFFFF>(a, b); + case 55: + return blend<0x7FFFFFFFFFFFFF>(a, b); + case 56: + return blend<0xFFFFFFFFFFFFFF>(a, b); + case 57: + return blend<0x1FFFFFFFFFFFFFF>(a, b); + case 58: + return blend<0x3FFFFFFFFFFFFFF>(a, b); + case 59: + return blend<0x7FFFFFFFFFFFFFF>(a, b); + case 60: + return blend<0xFFFFFFFFFFFFFFF>(a, b); + case 61: + return blend<0x1FFFFFFFFFFFFFFF>(a, b); + case 62: + return blend<0x3FFFFFFFFFFFFFFF>(a, b); + case 63: + return blend<0x7FFFFFFFFFFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu_one_fourth(const void* ptr) { + // Fast path if only load element number of 16. + // Note: We didn't merge it as fast path of loadu(const void* ptr, T count), + // Because loadu(const void* ptr, T count) requires zero initialization for + // upper 384 bits. However, by using _mm512_castsi128_si512, the upper 384 + // bits of the result are undefined. + // TODO We can use _mm512_zextsi128_si512 in the future, + // since gcc 9.3 doesn't support it now. + __m128i input_128 = _mm_loadu_si128(reinterpret_cast(ptr)); + return _mm512_castsi128_si512(input_128); + } + static Vectorized loadu(const void* ptr, T count) { + if (count == size()) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } else if (count == 16) { + // Fast path if only load element number of 16 + return loadu_one_fourth(ptr); + } else { + __mmask64 mask = (1ULL << count) - 1; + auto ones = _mm512_set1_epi8(1); + return _mm512_mask_loadu_epi8(ones, mask, ptr); + } + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + if (count == 16) { + // Fast path if only store element number of 16 + _mm_storeu_si128( + reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values)); + } else { + __mmask64 mask = (1ULL << count) - 1; + _mm512_mask_storeu_epi8(ptr, mask, values); + } + } + } + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + + Vectorized neg() const; + + Vectorized abs() const { + return _mm512_abs_epi8(values); + } + + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +class Vectorized : public Vectorized8 { + public: + using Vectorized8::Vectorized8; + + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epu8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi64(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi16(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +template <> +Vectorized inline operator-( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +// Negation. Defined here so we can utilize operator- +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi64(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_512( + const Vectorized& a, + const Vectorized& b, + Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying int8_t +#ifndef CPU_CAPABILITY_AVX512 + return int_elementwise_binary_512(a, b, std::multiplies()); +#else + __m512i mask00FF = _mm512_set1_epi16(0x00FF); + __m512i a_lo = _mm512_srai_epi16(_mm512_slli_epi16(a, 8), 8); + __m512i b_lo = _mm512_srai_epi16(_mm512_slli_epi16(b, 8), 8); + __m512i a_hi = _mm512_srai_epi16(a, 8); + __m512i b_hi = _mm512_srai_epi16(b, 8); + __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF); + __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8); + __m512i res = _mm512_or_si512(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t +#ifndef CPU_CAPABILITY_AVX512 + return int_elementwise_binary_512(a, b, std::multiplies()); +#else + __m512i mask00FF = _mm512_set1_epi16(0x00FF); + __m512i a_lo = _mm512_and_si512(a, mask00FF); + __m512i b_lo = _mm512_and_si512(b, mask00FF); + __m512i a_hi = _mm512_srli_epi16(a, 8); + __m512i b_hi = _mm512_srli_epi16(b, 8); + __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF); + __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8); + __m512i res = _mm512_or_si512(res_hi, res_lo); + return res; +#endif +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi64(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi32(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi16(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epi8(a, b); +} + +template <> +Vectorized inline minimum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_min_epu8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi64(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi32(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi16(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epi8(a, b); +} + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return _mm512_max_epu8(a, b); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi64(max_val, _mm512_max_epi64(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi32(max_val, _mm512_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi16(max_val, _mm512_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epi8(max_val, _mm512_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_val, + const Vectorized& max_val) { + return _mm512_min_epu8(max_val, _mm512_max_epu8(a, min_val)); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi64(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epi8(max_val, a); +} + +template <> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_val) { + return _mm512_min_epu8(max_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi64(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epi8(min_val, a); +} + +template <> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_val) { + return _mm512_max_epu8(min_val, a); +} + +template +std::enable_if_t< + !(std::is_same_v || std::is_same_v), + Vectorized< + int32_t>> inline convert_to_int32(const T* ptr, int count = Vectorized::size()) { + return Vectorized::loadu(ptr, count); +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const int8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm512_cvtepi8_epi32( + _mm_loadu_si128(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm512_cvtepi8_epi32(_mm512_castsi512_si128(a)); + } +} + +template +std:: + enable_if_t, Vectorized> inline convert_to_int32( + const uint8_t* ptr, + int count = Vectorized::size()) { + if (count == Vectorized::size()) { + return _mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast(ptr))); + } else { + auto a = Vectorized::loadu(ptr, count); + return _mm512_cvtepu8_epi32(_mm512_castsi512_si128(a)); + } +} + +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/( + const Vectorized& a, + const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} + +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_si512(a, b); +} +template < + class T, + typename std::enable_if_t< + std::is_base_of>::value, + int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm512_xor_si512(a, _mm512_set1_epi32(-1)); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +inline Vectorized Vectorized::eq( + const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne( + const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt( + const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge( + const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt( + const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le( + const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template < + bool left_shift, + typename T, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + int> = 0> +Vectorized inline shift_512_8( + const Vectorized& a, + const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 512 bits as an + // array of 8-bit elements, and considering pairs of neighboring + // elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m512i ctl_0_1 = _mm512_set_epi8( + 62, + 0x80, + 60, + 0x80, + 58, + 0x80, + 56, + 0x80, + 54, + 0x80, + 52, + 0x80, + 50, + 0x80, + 48, + 0x80, + 46, + 0x80, + 44, + 0x80, + 42, + 0x80, + 40, + 0x80, + 38, + 0x80, + 36, + 0x80, + 34, + 0x80, + 32, + 0x80, + 30, + 0x80, + 28, + 0x80, + 26, + 0x80, + 24, + 0x80, + 22, + 0x80, + 20, + 0x80, + 18, + 0x80, + 16, + 0x80, + 14, + 0x80, + 12, + 0x80, + 10, + 0x80, + 8, + 0x80, + 6, + 0x80, + 4, + 0x80, + 2, + 0x80, + 0, + 0x80); + __m512i ctl_1_0 = _mm512_set_epi8( + 0x80, + 63, + 0x80, + 61, + 0x80, + 59, + 0x80, + 57, + 0x80, + 55, + 0x80, + 53, + 0x80, + 51, + 0x80, + 49, + 0x80, + 47, + 0x80, + 45, + 0x80, + 43, + 0x80, + 41, + 0x80, + 39, + 0x80, + 37, + 0x80, + 35, + 0x80, + 33, + 0x80, + 31, + 0x80, + 29, + 0x80, + 27, + 0x80, + 25, + 0x80, + 23, + 0x80, + 21, + 0x80, + 19, + 0x80, + 17, + 0x80, + 15, + 0x80, + 13, + 0x80, + 11, + 0x80, + 9, + 0x80, + 7, + 0x80, + 5, + 0x80, + 3, + 0x80, + 1); + + // Masks for bitwise and operation, treating 512 bits as an array of + // 8-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m512i keep_0 = _mm512_set1_epi16(0xFF); + __m512i keep_1 = _mm512_set1_epi16(0xFF00); + + // Take each 8-bit element with idx%2==0 from input array to be + // shifted and extend it to 16 bits so that 0s are added to the + // right. Then, perform shifting on this 16-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 16 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1); + __m512i b0 = _mm512_and_si512(b, keep_0); + __m512i c0; + if (left_shift) + c0 = _mm512_sllv_epi16(a0, b0); + else if constexpr (std::is_same_v) + c0 = _mm512_srav_epi16(a0, b0); + else + c0 = _mm512_srlv_epi16(a0, b0); + c0 = _mm512_shuffle_epi8(c0, ctl_1_0); + + // Perform shifting the same way for input array elements with + // idx%2==1. + __m512i a1 = _mm512_and_si512(a, keep_1); + __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0); + __m512i c1; + if (left_shift) + c1 = _mm512_sllv_epi16(a1, b1); + else if constexpr (std::is_same_v) + c1 = _mm512_srav_epi16(a1, b1); + else + c1 = _mm512_srlv_epi16(a1, b1); + c1 = _mm512_and_si512(c1, keep_1); + + // Merge partial results into the final result. + __m512i c = _mm512_or_si512(c0, c1); + + return c; +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return _mm512_sllv_epi16(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi64(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return _mm512_srav_epi16(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + return shift_512_8(a, b); +} + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..5ad0997df7d03d19214f50c9fa81b8d1f03ab02c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_mask.h @@ -0,0 +1,395 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + dst_n, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n && dst_n >= 1) && + (std::is_same_v || std::is_same_v)>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; + VectorizedN tmp_vec; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int_mask = VecMask(tmp_vec).template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + 1, + mask_t, + 1, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto int_mask = vec_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + auto zero = _mm_set1_epi8(0); + auto temp = _mm_mask_loadu_epi8(zero, mmask, ptr); + return Vectorized( + _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0)); + } +}; + +template +struct VecMaskLoad< + data_t, + 2, + mask_t, + 1, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + at::vec::Vectorized zero_vec(0); + auto int_mask = vec_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + at::vec::VectorizedN result; + if constexpr (std::is_same_v) { + result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } else { + result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castps_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castpd_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; + +template <> +inline bool VecMask::all_zero() const { + __mmask16 mask = _mm512_test_epi32_mask(mask_[0], mask_[0]); + return mask == 0; +} + +template <> +inline bool VecMask::is_masked(int i) const { + return _mm512_movepi32_mask(mask_[0]) & (1 << i); +} + +template <> +inline bool VecMask::all_masked() const { + __mmask16 mask = _mm512_movepi32_mask(mask_[0]); + return mask == 0xffff; +} + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = + all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 8) { + return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + +#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ + T, N, return_type, method, args_def, args) \ + template <> \ + inline return_type VecMask::method args_def const { \ + return cast().method args; \ + } + +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) +VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) +VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) + +#undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT + +#endif + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h new file mode 100644 index 0000000000000000000000000000000000000000..270b96bac433b52d68329bf0a452381d0c8170a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h @@ -0,0 +1,1552 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over +// Vectorized::float_num_vecs iterations. + +namespace at { +namespace vec { +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_AVX512) + +#ifdef _MSC_VER +__declspec(align(64)) struct Vectorizedqi { + protected: + __m512i vals; +#else +struct Vectorizedqi { + protected: + __m512i vals __attribute__((aligned(64))); +#endif + + public: + Vectorizedqi() { + vals = _mm512_setzero_si512(); + } + Vectorizedqi(__m512i v) : vals(v) {} + operator __m512i() const { + return vals; + } +}; + +template +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + T min_val, + T max_val); + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first [[maybe_unused]], + __m512i second [[maybe_unused]], + int32_t min_val [[maybe_unused]], + int32_t max_val [[maybe_unused]]) { + // This function is for linkage only, will not be used + TORCH_CHECK(false, "pack_saturate_and_clamp is not supported"); + return __m512i{}; +} + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + int8_t min_val, + int8_t max_val) { + __m512i packed_and_sat = _mm512_packs_epi16(first, second); + return _mm512_max_epi8( + _mm512_set1_epi8(min_val), + _mm512_min_epi8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template <> +inline __m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + uint8_t min_val, + uint8_t max_val) { + __m512i packed_and_sat = _mm512_packus_epi16(first, second); + return _mm512_max_epu8( + _mm512_set1_epi8(min_val), + _mm512_min_epu8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + float>> inline convert_int8_to_float(at::vec::Vectorized src) { + // Note: this function only convert inputs number of elements equal to + // at::vec::Vectorized.size() Only handle first 16*8 bits + __m128i input_128 = _mm512_castsi512_si128(src); + // Convert from 16*uint8/int8 to 16*int32 + __m512i input_512_extended; + if constexpr (std::is_same_v) + input_512_extended = _mm512_cvtepu8_epi32(input_128); + else + input_512_extended = _mm512_cvtepi8_epi32(input_128); + // Convert from 16*int32 to 16*float32 + return _mm512_cvtepi32_ps(input_512_extended); +} + +template +at::vec::Vectorized inline convert_float_to_int8( + at::vec::Vectorized src); + +template <> +at::vec::Vectorized inline convert_float_to_int8( + at::vec::Vectorized src) { + // Convert from float32 to int32 with truncation + __m512i x_values_int32 = _mm512_cvttps_epi32(src); + + // Convert from int32 to int16 using signed saturation + __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32); + + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to int8 using unsigned saturation + __m512i xyzw_clamped_v = pack_saturate_and_clamp( + xy_packed_v, xy_packed_v, min_val, max_val); + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); +} + +template <> +at::vec::Vectorized inline convert_float_to_int8( + at::vec::Vectorized src) { + // The type of *_val should be int32_t to ensure correct clamping behavior. + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m512 float32_min_val = _mm512_set1_ps(float(min_val)); + __m512 float32_max_val = _mm512_set1_ps(float(max_val)); + __m512 float32_src = _mm512_max_ps(src, float32_min_val); + float32_src = _mm512_min_ps(float32_src, float32_max_val); + __m512i int32_src_clamped = _mm512_cvttps_epi32(float32_src); + __m128i int8_src = _mm512_cvtepi32_epi8(int32_src_clamped); + return _mm512_castsi128_si512(int8_src); +} + +template +__FORCE_INLINE void QuantizeAvx512( + const float* src, + T* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 16; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m512i min_v = _mm512_set1_epi32(min_val); + const __m512i max_v = _mm512_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m512 inverse_scale_v = _mm512_set1_ps(inverse_scale); + // clang-format off + static const __m512i shuffle_mask_v = _mm512_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + __m512i permute_mask_l8_v = _mm512_set_epi32( + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x0c, + 0x08, + 0x04, + 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm512_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // y + __m512 y_vals = _mm512_load_ps(src + i + VLEN); + __m512 y_transformed_v = _mm512_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm512_min_ps(y_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // z + __m512 z_vals = _mm512_load_ps(src + i + 2 * VLEN); + __m512 z_transformed_v = _mm512_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm512_min_ps(z_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // w + __m512 w_vals = _mm512_load_ps(src + i + 3 * VLEN); + __m512 w_transformed_v = _mm512_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm512_min_ps(w_transformed_v, _mm512_set1_ps(int32_float_max_val)); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_transformed_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_transformed_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_transformed_v); + + // add zero point + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + y_rounded_v = _mm512_add_epi32(y_rounded_v, _mm512_set1_epi32(zero_point)); + z_rounded_v = _mm512_add_epi32(z_rounded_v, _mm512_set1_epi32(zero_point)); + w_rounded_v = _mm512_add_epi32(w_rounded_v, _mm512_set1_epi32(zero_point)); + + __m512i xy_packed_v = _mm512_packs_epi32(x_rounded_v, y_rounded_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_rounded_v, w_rounded_v); + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX512 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + __m512i x_clipped_v = + _mm512_max_epi32(min_v, _mm512_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm512_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm512_permutexvar_epi32(permute_mask_l8_v, x_clipped_v); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(dst + i), + _mm512_castsi512_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. + transformed = zero_point + std::nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type size() { + return 16; + } + + static constexpr int float_num_vecs() { + return 1; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {(Vectorized(float_vals) - zero_point) * scale}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + Vectorized retval; + auto rhs_data = (__m512)rhs[0]; + at::native::quantize_vec( + scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 16); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi32( + _mm512_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm512_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + + __m512 scaled = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier_v); + __m512i rounded = _mm512_cvtps_epi32(scaled); + return _mm512_add_epi32(rounded, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m512i RequantizeAvx512( + const std::array, 4>& inp, + __m512 multiplier, + __m512i zp) { + static_assert( + std::is_same_v || std::is_same_v, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + __m512i permute_mask_v = _mm512_set_epi32( + 0x0f, + 0x0b, + 0x07, + 0x03, + 0x0e, + 0x0a, + 0x06, + 0x02, + 0x0d, + 0x09, + 0x05, + 0x01, + 0x0c, + 0x08, + 0x04, + 0x00); + __m512 x_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier); + __m512 y_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[1]), multiplier); + __m512 z_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[2]), multiplier); + __m512 w_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[3]), multiplier); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m512i x_v = _mm512_add_epi32(x_rounded_v, zp); + __m512i y_v = _mm512_add_epi32(y_rounded_v, zp); + __m512i z_v = _mm512_add_epi32(z_rounded_v, zp); + __m512i w_v = _mm512_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m512i xy_packed_v = _mm512_packs_epi32(x_v, y_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_v, w_v); + + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 x8-11 y8-11 z8-11 w8-11 + * x12-15 y12-15 z12-15 w12-15 + */ + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + return xyzw_clamped_v; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + + // This is added to avoid error: definition of implicit copy assignment + // operator for 'Vectorized' is deprecated because it has a + // user-declared copy constructor [-Werror,-Wdeprecated-copy] + Vectorized& operator=(const Vectorized&) = default; + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + private: + __m512i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm512_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_neg_zp_premul) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi8(_mm512_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512i int32_val0 = cvtepi8_epi32(int_val0); + __m512i int32_val1 = cvtepi8_epi32(int_val1); + __m512i int32_val2 = cvtepi8_epi32(int_val2); + __m512i int32_val3 = cvtepi8_epi32(int_val3); + +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]); +#else + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); +#endif + + __m512i int32_b0 = cvtepi8_epi32(int_b0); + __m512i int32_b1 = cvtepi8_epi32(int_b1); + __m512i int32_b2 = cvtepi8_epi32(int_b2); + __m512i int32_b3 = cvtepi8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { + vals = vals_; + } + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) {} + + // This is added to avoid error: definition of implicit copy assignment + // operator for 'Vectorized' is deprecated because it has a + // user-declared copy constructor [-Werror,-Wdeprecated-copy] + Vectorized& operator=(const Vectorized&) = default; + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + private: + __m512i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm512_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + + return {val0, val1, val2, val3}; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = (Vectorized(float_val0) - zero_point) * scale; + auto val1 = (Vectorized(float_val1) - zero_point) * scale; + auto val2 = (Vectorized(float_val2) - zero_point) * scale; + auto val3 = (Vectorized(float_val3) - zero_point) * scale; + + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epu8(_mm512_max_epu8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]); + __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]); + __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]); + __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]); +#else + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); +#endif + + __m512i int32_val0 = cvtepu8_epi32(int_val0); + __m512i int32_val1 = cvtepu8_epi32(int_val1); + __m512i int32_val2 = cvtepu8_epi32(int_val2); + __m512i int32_val3 = cvtepu8_epi32(int_val3); + +#if defined(_MSC_VER) && !defined(__clang__) + __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]); +#else + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); +#endif + + __m512i int32_b0 = cvtepu8_epi32(int_b0); + __m512i int32_b1 = cvtepu8_epi32(int_b1); + __m512i int32_b2 = cvtepu8_epi32(int_b2); + __m512i int32_b3 = cvtepu8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + return { + Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#else + +// NOTE: These are low-performance implementations that we fall back on. + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / 8; + } + + static constexpr int int_num_vecs() { + return size() / 8; + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (const auto i : c10::irange(size())) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul [[maybe_unused]]) const { + float_vec_return_type rv; + for (const auto i : c10::irange(float_num_vecs())) { + float tmp_vals[16]; + for (const auto j : c10::irange(16)) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], zero_point[j], T(vals[16 * i + j])); + } + rv[i] = Vectorized( + tmp_vals[0], + tmp_vals[1], + tmp_vals[2], + tmp_vals[3], + tmp_vals[4], + tmp_vals[5], + tmp_vals[6], + tmp_vals[7], + tmp_vals[8], + tmp_vals[9], + tmp_vals[10], + tmp_vals[11], + tmp_vals[12], + tmp_vals[13], + tmp_vals[14], + tmp_vals[15]); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + Vectorized scale_zp_premul; + return dequantize(scale, zero_point, scale_zp_premul); + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (const auto i : c10::irange(size())) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = + std::nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (const auto i : c10::irange(std::decay_t::size())) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct is_vec_specialized_for : std::bool_constant {}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See + // https://github.com/pytorch/pytorch/issues/32502 for more details. We do + // not initialize arrays to zero using "={0}" because gcc would compile it + // to two instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale [[maybe_unused]]) { + std::array qvals; + std::array float_vals; + + for (const auto i : c10::irange(float_num_vecs())) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (const auto i : c10::irange(size())) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (const auto i : c10::irange(int_num_vecs())) { + for (const auto j : c10::irange(elem_per_int_vec)) { + int32_t rounded = + std::nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum( + const Vectorized& a, + const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(MSVC) + +} // namespace CPU_CAPABILITY +} // namespace vec +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h new file mode 100644 index 0000000000000000000000000000000000000000..3f06f3fc806c9056c0e8361a320b69c0d2003ba5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h @@ -0,0 +1,1537 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \ + defined(__ARM_FEATURE_SVE) +// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161 +#pragma GCC optimize("no-tree-vectorize") +#endif + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] +// +// Note [Do not compile initializers with AVX] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// If you define a static initializer in this file, the initialization will use +// AVX instructions because these object files are compiled with AVX enabled. +// We need to avoid non-trivial global data in these architecture specific files +// because there's no way to guard the global initializers with CPU capability +// detection. +// +// See https://github.com/pytorch/pytorch/issues/37577 for an instance +// of this bug in the past. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__GNUC__) +#define __FORCE_INLINE __attribute__((always_inline)) inline +#elif defined(_MSC_VER) +#define __FORCE_INLINE __forceinline +#endif + +#if defined(_MSC_FULL_VER) +/* +https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170 +Use _MSC_FULL_VER to identify current compiler is msvc, +Windows llvm will not have this definition. +*/ +#define __msvc_cl__ +#endif + +// These macros helped us unify vec_base.h +#ifdef CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(64))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(64)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 64 +#define int_vector __m512i +#elif defined(__aarch64__) && \ + !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 +// SVE code expects 256-vectors; leave that set for SVE? +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(16))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(16)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 16 +#else // CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(32))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(32)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 32 +#define int_vector __m256i +#endif // CPU_CAPABILITY_AVX512 + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { +// at::Half and at::BFloat16 should be treated as floating point +template +struct is_floating_point + : std::integral_constant< + bool, + std::is_floating_point_v || std::is_same_v || + std::is_same_v> {}; + +template +constexpr bool is_floating_point_v = is_floating_point::value; + +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; + +template +struct is_8bit_integer + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> { +}; + +template +constexpr bool is_8bit_integer_v = is_8bit_integer::value; + +template +struct int_of_size; + +#define DEFINE_INT_OF_SIZE(int_t) \ + template <> \ + struct int_of_size { \ + using type = int_t; \ + } + +DEFINE_INT_OF_SIZE(int64_t); +DEFINE_INT_OF_SIZE(int32_t); +DEFINE_INT_OF_SIZE(int16_t); +DEFINE_INT_OF_SIZE(int8_t); + +#undef DEFINE_INT_OF_SIZE + +template +using int_same_size_t = typename int_of_size::type; + +/** + * Detect at compile time whether Vectorized has an explicit + * specialization for T. (You are required to specialize this type + * whenever you specialize Vectorized). Useful for generic algorithms + * to decide whether to rely on a specialization being fast. For + * example, they might choose to handle reduced-precision floating + * point types directly if they're supported, or convert through float + * if not. + */ +#if defined(__s390x__) +template +#else +template +#endif +struct is_vec_specialized_for : std::bool_constant { +}; + +template +constexpr bool is_vec_specialized_for_v = is_vec_specialized_for::value; + +// NOTE: If you specialize Vectorized on a type, you must define all +// operations! You must also specialize is_vec_specialized_for for +// that type. + +// emulates Vectorized types +#if defined(__s390x__) +template +#else +template +#endif +struct Vectorized { + private: + __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; + + public: + using value_type = T; + using size_type = int; + + static constexpr size_type kSize = VECTOR_WIDTH / sizeof(T); + static constexpr size_type size() { + return kSize; + } + Vectorized() : values{static_cast(0)} {} + Vectorized(T val) { + for (int i = 0; i != size(); i++) { + values[i] = val; + } + } + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) : values{vals...} {} + Vectorized(const T (&arr)[kSize]) { + std::memcpy(values, arr, sizeof(values)); + } + // This also implies const T& operator[](int idx) const + inline operator const T*() const { + return values; + } + // This also implies T& operator[](int idx) + inline operator T*() { + return values; + } + // Return the values as char* for type punning + auto as_bytes() const -> const char* { + return reinterpret_cast(values); + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + int64_t mask = mask_; + Vectorized vector; + for (const auto i : c10::irange(size())) { + if (mask & 0x01) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + mask = mask >> 1; + } + return vector; + } +// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 +#if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE) + static Vectorized __attribute__((optimize("-fno-tree-loop-vectorize"))) + blendv( + const Vectorized& a, +#else + static Vectorized blendv( + const Vectorized& a, +#endif + const Vectorized& b, + const Vectorized& mask) { + Vectorized vector; + int_same_size_t buffer[size()]; + mask.store(buffer); + for (const auto i : c10::irange(size())) { + if (buffer[i] & 0x01) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + } + return vector; + } + template // step sometimes requires a higher precision type + // (e.g., T=int, step_t=double) + static Vectorized arange( + T base = static_cast(0), + step_t step = static_cast(1)) { + Vectorized vector; + for (const auto i : c10::irange(size())) { + vector.values[i] = base + i * step; + } + return vector; + } + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { + Vectorized vector; + for (const auto i : c10::irange(size())) { + if (i < count) { + vector[i] = b[i]; + } else { + vector[i] = a[i]; + } + } + return vector; + } + static Vectorized loadu(const void* ptr) { + Vectorized vector; + std::memcpy(vector.values, ptr, VECTOR_WIDTH); + return vector; + } + static Vectorized loadu(const void* ptr, int64_t count) { + Vectorized vector; + std::memcpy(vector.values, ptr, count * sizeof(T)); + return vector; + } + static Vectorized loadu_one_fourth(const void* ptr) { + static_assert( + std::is_same_v || std::is_same_v, + "For byte types only"); + return Vectorized::loadu(ptr, 8); + } + + void store(void* ptr, int count = size()) const { + std::memcpy(ptr, values, count * sizeof(T)); + } + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (values[i] == static_cast(0)) { + mask |= (1 << i); + } + } + return mask; + } + Vectorized isnan() const { + Vectorized vector; + for (int64_t i = 0; i != size(); i++) { + if (_isnan(values[i])) { + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); + } else { + std::memset(static_cast(vector.values + i), 0, sizeof(T)); + } + } + return vector; + } + bool has_inf_nan() const { + for (int64_t i = 0; i != size(); i++) { + if (_isnan(values[i]) || _isinf(values[i])) { + return true; + } + } + return false; + } +// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows +// Arm64 +// See +// https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 +#if defined(_WIN32) && defined(__aarch64__) && \ + ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) + Vectorized map(T (*const f)(T)) const { + Vectorized ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = f(values[i]); + if (++i < size()) + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(T)) const { + T ret = 0; + for (int64_t i = 0; i < size(); i++) { + ret = f(ret, values[i]); + if (++i < size()) + ret = f(ret, values[i]); + } + return ret; + } +#else + Vectorized map(T (*const f)(T)) const { + Vectorized ret; + for (int64_t i = 0; i != size(); i++) { + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(T)) const { + T ret = 0; + for (int64_t i = 0; i != size(); i++) { + ret = f(ret, values[i]); + } + return ret; + } +#endif + Vectorized map(T (*const f)(const T&)) const { + Vectorized ret; + for (int64_t i = 0; i != size(); i++) { + ret[i] = f(values[i]); + } + return ret; + } + T reduce(T (*const f)(const T&)) const { + T ret = 0; + for (int64_t i = 0; i != size(); i++) { + ret = f(ret, values[i]); + } + return ret; + } + template < + typename other_t_abs = T, + typename std::enable_if_t< + !is_floating_point_v && + !c10::is_complex::value, + int> = 0> + Vectorized abs() const { + // other_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_abs must be T"); + return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); + } + template < + typename float_t_abs = T, + typename std::enable_if_t, int> = 0> + Vectorized abs() const { + // float_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "float_t_abs must be T"); + // Specifically deal with floating-point because the generic code above + // won't handle -0.0 (which should result in 0.0) properly. + return map([](T x) -> T { return std::abs(x); }); + } + template < + typename complex_t_abs = T, + typename std::enable_if_t::value, int> = 0> + Vectorized abs() const { + // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "complex_t_abs must be T"); + // Specifically map() does not perform the type conversion needed by abs. + return map([](T x) { return static_cast(std::abs(x)); }); + } + + template < + typename other_t_sgn = T, + typename std::enable_if_t::value, int> = 0> + Vectorized sgn() const { + return map(at::native::sgn_impl); + } + + template < + typename other_t_angle = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized angle() const { + // other_t_angle is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_angle must be T"); + return map(at::native::angle_impl); // compiler is unable to resolve the + // overload without + } + template < + typename complex_t_angle = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized angle() const { + // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_angle must be T"); + return map([](T x) { return static_cast(std::arg(x)); }); + } + template < + typename other_t_real = T, + typename std::enable_if_t::value, int> = 0> + Vectorized real() const { + // other_t_real is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_real must be T"); + return *this; + } + template < + typename complex_t_real = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized real() const { + // complex_t_real is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_real must be T"); + return map([](T x) { return static_cast(x.real()); }); + } + template < + typename other_t_imag = T, + typename std::enable_if_t::value, int> = 0> + Vectorized imag() const { + // other_t_imag is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_imag must be T"); + return Vectorized(0); + } + template < + typename complex_t_imag = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized imag() const { + // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_imag must be T"); + return map([](T x) { return static_cast(x.imag()); }); + } + template < + typename other_t_conj = T, + typename std::enable_if_t::value, int> = 0> + Vectorized conj() const { + // other_t_conj is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_conj must be T"); + return *this; + } + template < + typename complex_t_conj = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized conj() const { + // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_conj must be T"); + return map([](T x) { return static_cast(std::conj(x)); }); + } + Vectorized acos() const { + return map(std::acos); + } + Vectorized acosh() const { + return map(std::acosh); + } + Vectorized asin() const { + return map(std::asin); + } + Vectorized asinh() const { + return map(std::asinh); + } + Vectorized atan() const { + return map(std::atan); + } + Vectorized atanh() const { + return map(std::atanh); + } + Vectorized atan2(const Vectorized& exp) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::atan2(values[i], exp[i]); + } + return ret; + } + template < + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized copysign(const Vectorized& sign) const { + Vectorized ret; + for (size_type i = 0; i < size(); i++) { + ret[i] = c10::copysign(values[i], sign[i]); + } + return ret; + } + Vectorized erf() const { + return map(std::erf); + } + Vectorized erfc() const { + return map(std::erfc); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return map(std::exp); + } + Vectorized exp2() const { + return map(exp2_impl); + } + Vectorized expm1() const { + return map(std::expm1); + } + Vectorized exp_u20() const { + return map(std::exp); + } + Vectorized fexp_u20() const { + return map(std::exp); + } + Vectorized frac() const { + return *this - this->trunc(); + } + template < + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized fmod(const Vectorized& q) const { + // U is for SFINAE purposes only. Make sure it is not changed. + static_assert(std::is_same_v, "U must be T"); + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::fmod(values[i], q[i]); + } + return ret; + } + Vectorized log() const { + return map(std::log); + } + Vectorized log10() const { + return map(std::log10); + } + Vectorized log1p() const { + return map(std::log1p); + } + template < + typename other_t_log2 = T, + typename std::enable_if_t::value, int> = 0> + Vectorized log2() const { + // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. + static_assert(std::is_same_v, "other_t_log2 must be T"); + return map(std::log2); + } + template < + typename complex_t_log2 = T, + typename std::enable_if_t::value, int> = + 0> + Vectorized log2() const { + // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. + static_assert( + std::is_same_v, "complex_t_log2 must be T"); + const T log_2 = T(std::log(2.0)); + return Vectorized(map(std::log)) / Vectorized(log_2); + } + Vectorized ceil() const { + return map(at::native::ceil_impl); + } + Vectorized cos() const { + return map(std::cos); + } + Vectorized cosh() const { + return map(std::cosh); + } + Vectorized floor() const { + return map(at::native::floor_impl); + } + Vectorized hypot(const Vectorized& b) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::hypot(values[i], b[i]); + } + return ret; + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized& x) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = calc_igamma(values[i], x[i]); + } + return ret; + } + Vectorized igammac(const Vectorized& x) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = calc_igammac(values[i], x[i]); + } + return ret; + } + Vectorized neg() const { + // NB: the trailing return type is needed because we need to coerce the + // return value back to T in the case of unary operator- incurring a + // promotion + return map([](T x) -> T { return -x; }); + } + Vectorized nextafter(const Vectorized& b) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::nextafter(values[i], b[i]); + } + return ret; + } + Vectorized round() const { + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. + return map(at::native::round_impl); + } + Vectorized sin() const { + return map(std::sin); + } + Vectorized sinh() const { + return map(std::sinh); + } + Vectorized tan() const { + return map(std::tan); + } + Vectorized tanh() const { + return map(std::tanh); + } + Vectorized trunc() const { + return map(at::native::trunc_impl); + } + Vectorized lgamma() const { + return map(std::lgamma); + } + Vectorized sqrt() const { + return map(std::sqrt); + } + Vectorized reciprocal() const { + return map([](T x) { return (T)1 / x; }); + } + Vectorized rsqrt() const { + return map([](T x) { return (T)1 / std::sqrt(x); }); + } + Vectorized pow(const Vectorized& exp) const { + Vectorized ret; + for (const auto i : c10::irange(size())) { + ret[i] = std::pow(values[i], exp[i]); + } + return ret; + } + T reduce_add() const { + return reduce([](T x, T y) -> T { return x + y; }); + } + T reduce_max() const { + return reduce(std::max); + } + + private: + template + inline Vectorized binary_pred(const Vectorized& other, Op op) const { + // All bits are set to 1 if the pred is true, otherwise 0. + Vectorized vector; + for (int64_t i = 0; i != size(); i++) { + if (op(values[i], other.values[i])) { + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); + } else { + std::memset(static_cast(vector.values + i), 0, sizeof(T)); + } + } + return vector; + } + + public: + Vectorized operator==(const Vectorized& other) const { + return binary_pred(other, std::equal_to()); + } + Vectorized operator!=(const Vectorized& other) const { + return binary_pred(other, std::not_equal_to()); + } + Vectorized operator>=(const Vectorized& other) const { + return binary_pred(other, std::greater_equal()); + } + Vectorized operator<=(const Vectorized& other) const { + return binary_pred(other, std::less_equal()); + } + Vectorized operator>(const Vectorized& other) const { + return binary_pred(other, std::greater()); + } + Vectorized operator<(const Vectorized& other) const { + return binary_pred(other, std::less()); + } + + private: + template + inline Vectorized binary_pred_bool(const Vectorized& other, Op op) + const { + // 1 if the pred is true, otherwise 0. + Vectorized vector; + for (int i = 0; i != size(); ++i) { + vector[i] = static_cast(op(values[i], other.values[i])); + } + return vector; + } + + public: + Vectorized eq(const Vectorized& other) const { + return binary_pred_bool(other, std::equal_to()); + } + Vectorized ne(const Vectorized& other) const { + return binary_pred_bool(other, std::not_equal_to()); + } + Vectorized gt(const Vectorized& other) const { + return binary_pred_bool(other, std::greater()); + } + Vectorized ge(const Vectorized& other) const { + return binary_pred_bool(other, std::greater_equal()); + } + Vectorized lt(const Vectorized& other) const { + return binary_pred_bool(other, std::less()); + } + Vectorized le(const Vectorized& other) const { + return binary_pred_bool(other, std::less_equal()); + } +}; + +template +Vectorized inline operator-(const Vectorized& a) { + return a.neg(); +} + +// There is an implicit conversion that would make this work if +// these operators weren't template functions, but they are template +// functions (and can't be moved to be non-member friends defined in +// the class body as suggested in +// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255 +// because we have a lot of disparate specializations of +// Vectorized). So, just explicitly make scalars work. +#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \ + template \ + Vectorized inline name(const Vectorized& a, T b) { \ + return name(a, Vectorized(b)); \ + } \ + template \ + Vectorized inline name(T a, const Vectorized& b) { \ + return name(Vectorized(a), b); \ + } +#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \ + VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op) + +template +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] + b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+) + +template +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] - b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-) + +template +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] * b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*) + +template +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] / b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/) + +template , int> = 0> +Vectorized inline operator%(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { + return a - a / b * b; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%) + +template +Vectorized inline operator||( + const Vectorized& a, + const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] || b[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||) + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (a[i] > b[i]) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum) + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (a[i] < b[i]) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; + if (_isnan(a[i])) { + // If either input is NaN, propagate a NaN. + // NOTE: The case where b[i] was NaN is handled correctly by the naive + // ternary operator above. + c[i] = a[i]; + } + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_vec, + const Vectorized& max_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); + } + return c; +} + +#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \ + template \ + Vectorized inline name( \ + const Vectorized& a, const Vectorized& b, T c) { \ + return name(a, b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + const Vectorized& a, T b, const Vectorized& c) { \ + return name(a, Vectorized(b), c); \ + } \ + \ + template \ + Vectorized inline name(const Vectorized& a, T b, T c) { \ + return name(a, Vectorized(b), Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + T a, const Vectorized& b, const Vectorized& c) { \ + return name(Vectorized(a), b, c); \ + } \ + \ + template \ + Vectorized inline name(T a, const Vectorized& b, T c) { \ + return name(Vectorized(a), b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name(T a, T b, const Vectorized& c) { \ + return name(Vectorized(a), Vectorized(b), c); \ + } + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_vec) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; + } + return c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min) + +struct Vectorizedi; + +#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { + int_vector buffer; +#if defined(CPU_CAPABILITY_AVX2) + int_vector a_buffer = + _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm256_load_si256(reinterpret_cast((const T*)b)); +#elif defined(CPU_CAPABILITY_AVX512) + int_vector a_buffer = + _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm512_load_si512(reinterpret_cast((const T*)b)); +#endif + buffer = op(a_buffer, b_buffer); + __at_align__ T results[Vectorized::size()]; + +#if defined(CPU_CAPABILITY_AVX2) + _mm256_store_si256(reinterpret_cast(results), buffer); +#elif defined(CPU_CAPABILITY_AVX512) + _mm512_store_si512(reinterpret_cast(results), buffer); +#endif + return Vectorized::loadu(results); +} + +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); +#endif +} +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); +#endif +} +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is + // always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); +#endif +} + +#else + +template +auto load(char const* data) -> T { + T ret; + std::memcpy(&ret, data, sizeof(ret)); + return ret; +} + +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { + static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); + __at_align__ intmax_t buffer[element_no]; + static_assert( + VECTOR_WIDTH % sizeof(intmax_t) == 0, + "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); + static_assert( + sizeof(buffer) == sizeof(Vectorized), + "sizeof(buffer) must match sizeof(Vectorized)"); + // We should be using memcpy in order to respect the strict aliasing rule + // see: https://github.com/pytorch/pytorch/issues/66119 + // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 + // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf) + const auto* a_data = a.as_bytes(); + const auto* b_data = b.as_bytes(); + // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t) + for (auto& out : buffer) { + out = op(load(a_data), load(b_data)); + a_data += sizeof(intmax_t); + b_data += sizeof(intmax_t); + } + assert(a_data == a.as_bytes() + sizeof(a)); + assert(b_data == b.as_bytes() + sizeof(b)); + return Vectorized::loadu(buffer); +} + +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_and()); +} +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_or()); +} +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return bitwise_binary_op(a, b, std::bit_xor()); +} + +#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^) + +template < + class T, + typename std:: + enable_if_t>, int> = 0> +inline Vectorized operator~(const Vectorized& a) { + using int_t = int_same_size_t; + Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 + return a ^ ones; +} + +template +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + constexpr T max_shift = sizeof(T) * CHAR_BIT; + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + T shift = b[i]; + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { + c[i] = 0; + } else { + c[i] = static_cast>(a[i]) << shift; + } + } + return c; +} + +template +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + // right shift value to retain sign bit for signed and no bits for unsigned + constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v; + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + T shift = b[i]; + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { + c[i] = a[i] >> max_shift; + } else { + c[i] = a[i] >> shift; + } + } + return c; +} + +template +inline Vectorized& operator+=(Vectorized& a, const Vectorized& b) { + a = a + b; + return a; +} +template +inline Vectorized& operator-=(Vectorized& a, const Vectorized& b) { + a = a - b; + return a; +} +template +inline Vectorized& operator/=(Vectorized& a, const Vectorized& b) { + a = a / b; + return a; +} +template +inline Vectorized& operator%=(Vectorized& a, const Vectorized& b) { + a = a % b; + return a; +} +template +inline Vectorized& operator*=(Vectorized& a, const Vectorized& b) { + a = a * b; + return a; +} + +template +inline Vectorized& operator<<=(Vectorized& a, const Vectorized& b) { + a = a << b; + return a; +} + +template +inline Vectorized& operator>>=(Vectorized& a, const Vectorized& b) { + a = a >> b; + return a; +} + +template +inline Vectorized fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b + c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd) + +template +inline Vectorized fnmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return -(a * b) + c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmadd) + +template +inline Vectorized fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return a * b - c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub) + +template +inline Vectorized fnmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { + return -(a * b) - c; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmsub) + +template +Vectorized inline operator&&( + const Vectorized& a, + const Vectorized& b) { + Vectorized ret; + for (int i = 0; i != Vectorized::size(); i++) { + ret[i] = a[i] && b[i]; + } + return ret; +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&) + +template +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + T>> inline gather(T const* base_addr, const Vectorized>& vindex) { + static constexpr int size = Vectorized::size(); + int_same_size_t index_arr[size]; + vindex.store(static_cast(index_arr)); + T buffer[size]; + for (const auto i : c10::irange(size)) { + buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; + } + return Vectorized::loadu(static_cast(buffer)); +} + +template +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + T const* base_addr, + const Vectorized>& vindex, + Vectorized& mask) { + static constexpr int size = Vectorized::size(); + T src_arr[size]; + int_same_size_t mask_arr[size]; // use int type so we can logical and + int_same_size_t index_arr[size]; + src.store(static_cast(src_arr)); + mask.store(static_cast(mask_arr)); + vindex.store(static_cast(index_arr)); + T buffer[size]; + for (const auto i : c10::irange(size)) { + if (mask_arr[i] & 0x01) { // check highest bit + buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; + } else { + buffer[i] = src_arr[i]; + } + } + mask = Vectorized(static_cast(0)); // "zero out" mask + return Vectorized::loadu(static_cast(buffer)); +} + +// Cast a given vector to another type without changing the bits representation. +// So a Vectorized of 512 bits containing all ones can be cast to a +// Vectorized of 512 bits containing all ones (i.e., eight negative +// 1s). A Vec of 256 bits containing all ones can be cast to a +// Vec of 256 bits containing all ones (i.e., four negative 1s). +// There is a struct here because we don't have static_if and I can't +// partially specialize a templated function. +template +struct CastImpl { + static inline Vectorized apply(const Vectorized& src) { + src_t src_arr[Vectorized::size()]; + src.store(static_cast(src_arr)); + return Vectorized::loadu(static_cast(src_arr)); + } +}; + +template +struct CastImpl { + static inline Vectorized apply(const Vectorized& src) { + return src; + } +}; + +template +inline Vectorized cast(const Vectorized& src) { + return CastImpl::apply(src); +} + +template > +inline Vectorized convert_to_int_of_same_size( + const Vectorized& src) { + static_assert(sizeof(T) == sizeof(IntType)); + static constexpr int size = Vectorized::size(); + + std::array src_arr = {}; + src.store(static_cast(src_arr.data())); + std::array buffer; + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) { + return static_cast(x); + }); + return Vectorized::loadu(static_cast(buffer.data())); +} + +template > +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src) { + static_assert(sizeof(T) == sizeof(IntType)); + static constexpr int size = Vectorized::size(); + + std::array src_arr; + src.store(static_cast(src_arr.data())); + std::array buffer; + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) { + return static_cast(x); + }); + return Vectorized::loadu(static_cast(buffer.data())); +} + +// clang-format off +// Example inputs for AVX512: +// a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// returns: +// Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// Example inputs for AVX2: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// clang-format on +template +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> +deinterleave2(const Vectorized& a, const Vectorized& b) { + static constexpr int size = Vectorized::size(); + static constexpr int half_size = size / 2; + T a_arr[size]; + T b_arr[size]; + T buffer1[size]; + T buffer2[size]; + a.store(static_cast(a_arr)); + b.store(static_cast(b_arr)); + for (const auto i : c10::irange(half_size)) { + buffer1[i] = a_arr[i * 2]; + buffer1[half_size + i] = b_arr[i * 2]; + buffer2[i] = a_arr[i * 2 + 1]; + buffer2[half_size + i] = b_arr[i * 2 + 1]; + } + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2) + +// clang-format off +// inverse operation of deinterleave2 +// Example inputs for AVX512: +// a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// returns, for AVX512: +// Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// Example inputs for AVX2 : a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// clang-format on +template +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> +interleave2(const Vectorized& a, const Vectorized& b) { + static constexpr int size = Vectorized::size(); + static constexpr int half_size = size / 2; + T a_arr[size]; + T b_arr[size]; + T buffer1[size]; + T buffer2[size]; + a.store(static_cast(a_arr)); + b.store(static_cast(b_arr)); + for (const auto i : c10::irange(half_size)) { + buffer1[i * 2] = a_arr[i]; + buffer1[i * 2 + 1] = b_arr[i]; + buffer2[i * 2] = a_arr[half_size + i]; + buffer2[i * 2 + 1] = b_arr[half_size + i]; + } + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); +} + +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2) + +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP +#undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC + +template +inline void convert(const src_T* src, dst_T* dst, int64_t n) { +#ifndef _MSC_VER +#pragma unroll +#endif + for ([[maybe_unused]] const auto i : c10::irange(n)) { + *dst = c10::convert(c10::load(src)); + src++; + dst++; + } +} + +template +inline Vectorized flip(const Vectorized& data) { + static constexpr int size = Vectorized::size(); + T output[size]; + T buffer[size]; + data.store(static_cast(buffer)); + for (const auto i : c10::irange(size)) { + output[i] = buffer[size - i - 1]; + } + return Vectorized::loadu(static_cast(output)); +} + +// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. +// `ld_src` is the leading dimension of `src` and `ld_dst` is the leading +// dimension of `dst`. +template +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst, + int M, + int N) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + dst[j * ld_dst + i] = src[i * ld_src + j]; + } + } +} + +template +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +// additional headers for more operations that depend on vec_base +#include +#include +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h new file mode 100644 index 0000000000000000000000000000000000000000..bdeeb6aae83470a41f9a238a726e74e6d68e80c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h @@ -0,0 +1,84 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + constexpr int count = std::min( + VectorizedN::size(), VectorizedN::size()); + __at_align__ src_t src_buf[VectorizedN::size()]; + src.store(src_buf); + __at_align__ dst_t dst_buf[VectorizedN::size()]; + for (int i = 0; i < count; i++) { + dst_buf[i] = static_cast(src_buf[i]); + } + return VectorizedN::loadu(dst_buf, count); + } +}; + +template +inline std::enable_if_t, Vectorized> convert( + const Vectorized& src) { + return src; +} + +template +inline std::enable_if_t, Vectorized> +convert(const Vectorized& src) { + return VecConvert::apply(src); +} + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + std::enable_if_t = 0> +inline VectorizedN convert(const VectorizedN& src) { + return VecConvert::apply(src); +} + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + bool keep = false, + std::enable_if_t = 0> +inline std::conditional_t, Vectorized> +convert(const VectorizedN& src) { + return VecConvert::apply(src); +} + +} // namespace CPU_CAPABILITY + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized&); + +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline Vectorized convert_from_float( + const Vectorized&, + const Vectorized&); + +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h new file mode 100644 index 0000000000000000000000000000000000000000..0d5395ca15d6fbeaaf1c46b16bca2fd3382c9f8c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h @@ -0,0 +1,123 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Transpose a [2, 32] matrix to [32, 2] +// Note: the output leading dimension should be 2, +// that is, the output must be contiguous +template > +static inline void transpose_pad_2x32_block( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int krem = 2, + int nrem = 32) { +#if defined(CPU_CAPABILITY_AVX512) + __m512i r0, r1; + __m512i d0, d1; + // load + if (nrem < 32) { + __mmask32 mask_krem_v = (1LL << nrem) - 1; + r0 = _mm512_maskz_loadu_epi16(mask_krem_v, src); + // if krem is not 2, pad with zeros + if (krem == 2) { + r1 = _mm512_maskz_loadu_epi16(mask_krem_v, src + ld_src); + } else { + r1 = _mm512_setzero_si512(); + } + } else { + r0 = _mm512_loadu_si512(reinterpret_cast(src)); + if (krem == 2) { + r1 = _mm512_loadu_si512(reinterpret_cast(src + ld_src)); + } else { + r1 = _mm512_setzero_si512(); + } + } + // transpose + d0 = _mm512_unpacklo_epi16(r0, r1); + d1 = _mm512_unpackhi_epi16(r0, r1); + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); + d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); + d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); + + // store + if (nrem < 16) { + __mmask32 mask_rem_v = (1LL << (nrem * 2)) - 1; + _mm512_mask_storeu_epi16(dst, mask_rem_v, d0); + } else if (nrem == 16) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + } else if (nrem < 32) { + __mmask32 mask_rem_v = (1LL << (nrem * 2 - 32)) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + _mm512_mask_storeu_epi16( + reinterpret_cast<__m512i*>(dst + 32), mask_rem_v, d1); + } else { + // normal store + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); + } +#else + TORCH_CHECK( + false, + "transpose_pad_2x32_block is only supported when avx512 is supported") +#endif +} + +// To use AMX to accelerate GEMM, +// reorder the memory format [K, N] -> [K/2, N, 2] +// Note: If K % 2 != 0, pad K implicitly +template > +static inline void pack_vnni2( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int64_t K, + int64_t N) { +#if defined(CPU_CAPABILITY_AVX512) + int64_t bk = 0; + int64_t _K = K / 2 * 2; + int64_t _N = N / 32 * 32; + for (; bk < _K; bk += 2) { + int64_t bn = 0; + for (; bn < _N; bn += 32) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); + } + } + if (K % 2 == 1) { + int64_t bn = 0; + for (; bn < _N; bn += 32) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); + } + } +#else + TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") +#endif +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h new file mode 100644 index 0000000000000000000000000000000000000000..509e79cfd16c12d1f66edde31eed0f114ca40f8d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h @@ -0,0 +1,318 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +namespace at::vec { +inline namespace CPU_CAPABILITY { + +/** + * The `VecMask` class provides a convenient interface for working with + * vectorized masks in SIMD operations. It encapsulates a `Vectorized` + * mask that can be directly usable in masked vectorized operations. It provides + * various methods for manipulating and accessing the mask elements: + * 1. `from` and `to`: Conversion between a vector of boolean values and a + * vectorized mask. + * 2. `cast`: Casts the mask to a different base type. + * 3. `all_zero`: Checks if all mask elements are zero. + * 4. `is_masked`: Checks if a specific element is masked. + * 5. `loadu`: Loads data from memory using the mask. + * 6. `all_masked`: Checks if all mask elements are masked. + * + * Some helper template classes are provided to simplify the specialization of + * the `VecMask` for the specific CPU arch: + * 1. `VecMaskLoad`: Loads data from memory using the mask. + * 2. `VecMaskTo`: Converts the mask to boolean. + * 3. `VecMaskCast`: Casts the mask to a different base type. + * + */ +template +class VecMask; + +template < + typename data_t, + int data_n, + typename mask_t, + int mask_n, + typename Enabled = void> +struct VecMaskLoad { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + constexpr typename VecMask::size_type size = + VecMask::size(); + static_assert(VectorizedN::size() >= size); + __at_align__ data_t data[size]; + __at_align__ mask_t mask[size]; + auto mask_ = VectorizedN(vec_mask); + mask_.store(mask); + for (int i = 0; i < size; i++) { + data[i] = mask[i] ? ptr[i] : static_cast(0); + } + return VectorizedN::loadu(data, size); + } +}; + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecMaskTo { + static inline VecMask apply( + const VecMask& vec_mask) { + auto zeros = VectorizedN(static_cast(0)); + auto ones = VectorizedN(static_cast(1)); + return VectorizedN::blendv( + zeros, ones, vec_mask.template cast()); + } +}; + +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> +struct VecMaskCast { + static inline VecMask apply( + const VecMask& vec_mask) { + return VecMask::from(VectorizedN(vec_mask)); + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + return vec_mask; + } +}; + +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m == static_cast(0); + }); + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m != static_cast(0); + }); + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return mask[i] != static_cast(0); + } +}; + +template +class VecMask { + public: + using size_type = int; + static constexpr size_type size() { + return VectorizedN::size(); + } + + private: + VectorizedN mask_; + + public: + VecMask() : mask_(static_cast(0)) {} + VecMask(const VectorizedN& mask) : mask_(mask) {} + + template = 0> + VecMask(const Vectorized& mask) : mask_(mask) {} + + template + static VecMask from(const VectorizedN& b_vec) { + __at_align__ U b_buf[size()]; + if constexpr (size() >= VectorizedN::size()) { + b_vec.store(b_buf); + for (int i = VectorizedN::size(); i < size(); i++) { + b_buf[i] = static_cast(0); + } + } else { + b_vec.store(b_buf, size()); + } + return from(b_buf); + } + + template + static VecMask from(U b) { + using int_t = int_same_size_t; + T mask = b ? c10::bit_cast((int_t)(~(int_t)0)) : (T)0; + return VectorizedN(mask); + } + + template + static VecMask from(U* b) { + using int_t = int_same_size_t; + __at_align__ T mask[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (int i = 0; i < size(); i++) { + *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0; + } + return VectorizedN(VectorizedN::loadu(mask)); + } + + template + static VecMask from(U* b, int count) { + using int_t = int_same_size_t; + __at_align__ T mask[size()]; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (int i = 0; i < count; i++) { + *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0; + } + return VectorizedN(VectorizedN::loadu(mask, count)); + } + + static VecMask blendv( + const VecMask& c, + const VecMask& b, + const VecMask& a) { + VectorizedN result = VectorizedN::blendv( + VectorizedN(c), VectorizedN(b), VectorizedN(a)); + return result; + } + + static VecMask set( + const VecMask& a, + const VecMask& b, + int64_t count = size()) { + VectorizedN result = VectorizedN::set( + VectorizedN(a), VectorizedN(b), count); + return result; + } + + void store(bool* b, int count = size()) { + constexpr int L = + (VectorizedN::size() + Vectorized::size() - 1) / + Vectorized::size(); + auto res = this->to(); + res.store(b, count); + return; + } + + template = 2, int> = 0> + inline VectorizedN to() const { + return VecMaskTo::apply(*this); + } + + template = 0> + inline Vectorized to() const { + return VecMaskTo::apply(*this); + } + + template + inline VecMask cast() const { + return VecMaskCast::apply(*this); + } + + inline bool all_zero() const { + return VecMaskCheck::all_zero(mask_); + } + + inline bool all_masked() const { + return VecMaskCheck::all_masked(mask_); + } + + inline bool is_masked(int i) const { + return VecMaskCheck::is_masked(mask_, i); + } + + inline operator VectorizedN() const { + return mask_; + } + + template = 0> + inline operator Vectorized() const { + return mask_[0]; + } + + inline Vectorized operator[](int i) const { + return mask_[i]; + } + + template < + typename U, + int L, + std::enable_if_t= 2 && VectorizedN::size() >= size(), int> = 0> + VectorizedN loadu(const U* ptr) const { + return VecMaskLoad::apply(ptr, *this); + } + + template < + typename U, + int L, + std::enable_if_t::size() >= size(), int> = 0> + Vectorized loadu(const U* ptr) const { + return VecMaskLoad::apply(ptr, *this); + } +}; + +#define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \ + template \ + inline VecMask op(const VecMask& a) { \ + return op(VectorizedN(a)); \ + } + +#define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \ + template < \ + typename T, \ + int N, \ + typename V, \ + int M, \ + std::enable_if_t::size() == VecMask::size(), int> = \ + 0> \ + inline VecMask op(const VecMask& a, const VecMask& b) { \ + return op( \ + VectorizedN(a), VectorizedN(b.template cast())); \ + } + +#define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \ + template < \ + typename T, \ + int N, \ + typename V, \ + int M, \ + std::enable_if_t::size() == VecMask::size(), int> = \ + 0> \ + inline VecMask op(const VecMask& a, const VecMask& b) { \ + return EXPR; \ + } + +VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator*) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b)) + +#undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL +#undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL +#undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h new file mode 100644 index 0000000000000000000000000000000000000000..5e7ed2de74177d868f3c11ef36a49f79986e2bc7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h @@ -0,0 +1,412 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::vec { +inline namespace CPU_CAPABILITY { + +/** + * @brief A class template representing a vectorized type with + * `N * Vectorized::size()` elements, aiming to support vectors of + * arbitrary size. A specific use case of it is to represent vectors + * converted from data types with different sizes but with the same + * number of vector elements, e.g., `VectorizedN` can be + * a vector converted from two `Vectorized`, `VectorizedN` + * can be a vector converted from two `Vectorized` etc. + * + * It supports most of the operations of `Vectorized` + * and the implementation delegates to `Vectorized` with loops over `N`. + * + * @tparam T The underlying type of the vectorized elements. + * @tparam N The number of underlying `Vectorized`. + */ +template +class VectorizedN { + public: + using value_type = T; + using size_type = int; + + static constexpr size_type size_T = sizeof(T); + static constexpr size_type size() { + return Vectorized::size() * N; + } + + private: + std::array, N> values; + + public: + // methods not implemented yet: + // variadic constructor, operator T*, as_bytes, zero_mask + +#define VECTORIZEDN_DEFINE_UNARY_OP(op) \ + VectorizedN op() const { \ + return unary_op([](const Vectorized& a) { return a.op(); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP(op) \ + VectorizedN op(const VectorizedN& other) const { \ + return binary_op( \ + other, [](const Vectorized& a, const Vectorized& b) { \ + return a.op(b); \ + }); \ + } + + template + inline VectorizedN unary_op(Op op) const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i]); + } + return result; + } + + template + inline VectorizedN binary_op(const VectorizedN& other, Op op) + const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i], other.values[i]); + } + return result; + } + + template + inline VectorizedN ternary_op( + const VectorizedN& other, + const VectorizedN& other2, + Op op) const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i], other.values[i], other2.values[i]); + } + return result; + } + + VectorizedN() = default; + + explicit VectorizedN(T val) { + for (int i = 0; i < N; ++i) { + values[i] = Vectorized(val); + } + } + + template = 0> + VectorizedN(const Vectorized& val) : values({val}) {} + + template = 0> + VectorizedN(const Vectorized& val_0, const Vectorized& val_1) + : values({val_0, val_1}) {} + + template = 0> + inline operator Vectorized() const { + return values[0]; + } + + inline const Vectorized& operator[](int i) const { + return values[i]; + } + + inline Vectorized& operator[](int i) { + return values[i]; + } + + template + static VectorizedN blend( + const VectorizedN& a, + const VectorizedN& b) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = + Vectorized::template blend(a.values[i], b.values[i]); + } + return result; + } + + static VectorizedN blendv( + const VectorizedN& a, + const VectorizedN& b, + const VectorizedN& mask) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = + Vectorized::blendv(a.values[i], b.values[i], mask.values[i]); + } + return result; + } + + template + static VectorizedN arange( + T base = static_cast(0), + step_t step = static_cast(1)) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::arange(base, step); + base += step * Vectorized::size(); + } + return result; + } + + static VectorizedN set( + const VectorizedN& a, + const VectorizedN& b, + int64_t count = size()) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + if (count > 0) { + result.values[i] = Vectorized::set( + a.values[i], + b.values[i], + std::min(count, (int64_t)Vectorized::size())); + count -= Vectorized::size(); + } else { + result.values[i] = a.values[i]; + } + } + return result; + } + + static VectorizedN loadu(const void* ptr) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = Vectorized::loadu(ptr); + ptr = static_cast(ptr) + Vectorized::size(); + } + return result; + } + + static VectorizedN loadu(const void* ptr, int64_t count) { + VectorizedN result; + for (int i = 0; i < N; ++i) { + if (count > 0) { + result.values[i] = Vectorized::loadu( + ptr, std::min(count, (int64_t)Vectorized::size())); + ptr = static_cast(ptr) + Vectorized::size(); + count -= Vectorized::size(); + } else { + result.values[i] = Vectorized((T)1); + } + } + return result; + } + + void store(void* ptr) const { + for (int i = 0; i < N; ++i) { + values[i].store(ptr); + ptr = static_cast(ptr) + Vectorized::size(); + } + } + + void store(void* ptr, int count) const { + for (int i = 0; i < N; ++i) { + values[i].store(ptr, std::min(count, (int)Vectorized::size())); + ptr = static_cast(ptr) + Vectorized::size(); + count -= Vectorized::size(); + if (count <= 0) { + break; + } + } + } + + bool has_inf_nan() const { + for (int i = 0; i < N; ++i) { + if (values[i].has_inf_nan()) { + return true; + } + } + return false; + } + + VectorizedN map(T (*const f)(T)) const { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = values[i].map(f); + } + return result; + } + + VectorizedN map(T (*const f)(const T&)) const { + VectorizedN result; + for (int i = 0; i < N; ++i) { + result.values[i] = values[i].map(f); + } + return result; + } + + VECTORIZEDN_DEFINE_UNARY_OP(isnan) + VECTORIZEDN_DEFINE_UNARY_OP(abs) + VECTORIZEDN_DEFINE_UNARY_OP(sgn) + VECTORIZEDN_DEFINE_UNARY_OP(angle) + VECTORIZEDN_DEFINE_UNARY_OP(real) + VECTORIZEDN_DEFINE_UNARY_OP(imag) + VECTORIZEDN_DEFINE_UNARY_OP(conj) + VECTORIZEDN_DEFINE_UNARY_OP(acos) + VECTORIZEDN_DEFINE_UNARY_OP(acosh) + VECTORIZEDN_DEFINE_UNARY_OP(asin) + VECTORIZEDN_DEFINE_UNARY_OP(asinh) + VECTORIZEDN_DEFINE_UNARY_OP(atan) + VECTORIZEDN_DEFINE_UNARY_OP(atanh) + VECTORIZEDN_DEFINE_BINARY_OP(atan2) + VECTORIZEDN_DEFINE_BINARY_OP(copysign) + VECTORIZEDN_DEFINE_UNARY_OP(erf) + VECTORIZEDN_DEFINE_UNARY_OP(erfc) + VECTORIZEDN_DEFINE_UNARY_OP(erfinv) + VECTORIZEDN_DEFINE_UNARY_OP(exp) + VECTORIZEDN_DEFINE_UNARY_OP(exp2) + VECTORIZEDN_DEFINE_UNARY_OP(expm1) + VECTORIZEDN_DEFINE_UNARY_OP(exp_u20) + VECTORIZEDN_DEFINE_UNARY_OP(fexp_u20) + VECTORIZEDN_DEFINE_UNARY_OP(frac) + VECTORIZEDN_DEFINE_BINARY_OP(fmod) + VECTORIZEDN_DEFINE_UNARY_OP(log) + VECTORIZEDN_DEFINE_UNARY_OP(log10) + VECTORIZEDN_DEFINE_UNARY_OP(log1p) + VECTORIZEDN_DEFINE_UNARY_OP(log2) + VECTORIZEDN_DEFINE_UNARY_OP(ceil) + VECTORIZEDN_DEFINE_UNARY_OP(cos) + VECTORIZEDN_DEFINE_UNARY_OP(cosh) + VECTORIZEDN_DEFINE_UNARY_OP(floor) + VECTORIZEDN_DEFINE_BINARY_OP(hypot) + VECTORIZEDN_DEFINE_UNARY_OP(i0) + VECTORIZEDN_DEFINE_UNARY_OP(i0e) + VECTORIZEDN_DEFINE_UNARY_OP(digamma) + VECTORIZEDN_DEFINE_BINARY_OP(igamma) + VECTORIZEDN_DEFINE_BINARY_OP(igammac) + VECTORIZEDN_DEFINE_UNARY_OP(neg) + VECTORIZEDN_DEFINE_BINARY_OP(nextafter) + VECTORIZEDN_DEFINE_UNARY_OP(round) + VECTORIZEDN_DEFINE_UNARY_OP(sin) + VECTORIZEDN_DEFINE_UNARY_OP(sinh) + VECTORIZEDN_DEFINE_UNARY_OP(tan) + VECTORIZEDN_DEFINE_UNARY_OP(tanh) + VECTORIZEDN_DEFINE_UNARY_OP(trunc) + VECTORIZEDN_DEFINE_UNARY_OP(lgamma) + VECTORIZEDN_DEFINE_UNARY_OP(sqrt) + VECTORIZEDN_DEFINE_UNARY_OP(reciprocal) + VECTORIZEDN_DEFINE_UNARY_OP(rsqrt) + VECTORIZEDN_DEFINE_BINARY_OP(pow) + VECTORIZEDN_DEFINE_BINARY_OP(operator==) + VECTORIZEDN_DEFINE_BINARY_OP(operator!=) + VECTORIZEDN_DEFINE_BINARY_OP(operator>=) + VECTORIZEDN_DEFINE_BINARY_OP(operator<=) + VECTORIZEDN_DEFINE_BINARY_OP(operator>) + VECTORIZEDN_DEFINE_BINARY_OP(operator<) + VECTORIZEDN_DEFINE_BINARY_OP(eq) + VECTORIZEDN_DEFINE_BINARY_OP(ne) + VECTORIZEDN_DEFINE_BINARY_OP(gt) + VECTORIZEDN_DEFINE_BINARY_OP(ge) + VECTORIZEDN_DEFINE_BINARY_OP(lt) + VECTORIZEDN_DEFINE_BINARY_OP(le) + +#undef VECTORIZEDN_DEFINE_UNARY_OP +#undef VECTORIZEDN_DEFINE_BINARY_OP +}; + +#define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op(const VectorizedN& a) { \ + return a.unary_op([](const Vectorized& a) { return op(a); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op( \ + const VectorizedN& a, const VectorizedN& b) { \ + return a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \ + return op(a, b); \ + }); \ + } + +#define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op( \ + const VectorizedN& a, \ + const VectorizedN& b, \ + const VectorizedN& c) { \ + return a.ternary_op( \ + b, \ + c, \ + [](const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& c) { return op(a, b, c); }); \ + } + +#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \ + template \ + inline VectorizedN& op( \ + VectorizedN& a, const VectorizedN& b) { \ + a = a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \ + return op(a, b); \ + }); \ + return a; \ + } + +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmadd) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmsub) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(clamp) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|) +VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^) +VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~) + +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=) +VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=) + +#undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL +#undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL +#undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL + +template +inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN acc_vec) { + Vectorized vec_result = acc_vec[0]; + for (int i = 1; i < N; i++) { + vec_result = vec_fun(vec_result, acc_vec[i]); + } + return vec_reduce_all(vec_fun, vec_result); +} + +template +std::ostream& operator<<(std::ostream& stream, const VectorizedN& vec_n) { + stream << "vec_n["; + for (int i = 0; i < N; ++i) { + if (i != 0) { + stream << ", "; + } + stream << vec_n[i]; + } + stream << ']'; + return stream; +} +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_quant.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_quant.h new file mode 100644 index 0000000000000000000000000000000000000000..04c81261f816eb2a1c66d7d3d3c64df2aaf43f7b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_quant.h @@ -0,0 +1,258 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Transpose a [4, 64] block to [64, 4] (with contiguous output, ld=4) +template > +static inline void transpose_pad_4x64_block( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int krem = 4, + int nrem = 64) { +#if defined(CPU_CAPABILITY_AVX512) + __m512i r[4]; + // Load with mask if partial + if (nrem < 64) { + __mmask64 mask = (1ULL << nrem) - 1; + for (int i = 0; i < krem; ++i) { + r[i] = _mm512_maskz_loadu_epi8(mask, src + i * ld_src); + } + for (int i = krem; i < 4; ++i) { + r[i] = _mm512_setzero_si512(); + } + } else { + for (int i = 0; i < krem; ++i) { + r[i] = _mm512_loadu_si512( + reinterpret_cast(src + i * ld_src)); + } + for (int i = krem; i < 4; ++i) { + r[i] = _mm512_setzero_si512(); + } + } + + // Transpose 4x64 bytes using unpack and shuffle + __m512i t0 = _mm512_unpacklo_epi8(r[0], r[1]); + __m512i t1 = _mm512_unpackhi_epi8(r[0], r[1]); + __m512i t2 = _mm512_unpacklo_epi8(r[2], r[3]); + __m512i t3 = _mm512_unpackhi_epi8(r[2], r[3]); + + __m512i u0 = _mm512_unpacklo_epi16(t0, t2); + __m512i u1 = _mm512_unpackhi_epi16(t0, t2); + __m512i u2 = _mm512_unpacklo_epi16(t1, t3); + __m512i u3 = _mm512_unpackhi_epi16(t1, t3); + + __m512i v0 = _mm512_shuffle_i32x4(u0, u1, 0x88); + __m512i v1 = _mm512_shuffle_i32x4(u0, u1, 0xdd); + __m512i v2 = _mm512_shuffle_i32x4(u2, u3, 0x88); + __m512i v3 = _mm512_shuffle_i32x4(u2, u3, 0xdd); + + __m512i r0 = _mm512_shuffle_i32x4(v0, v2, 0x88); + __m512i r1 = _mm512_shuffle_i32x4(v1, v3, 0x88); + __m512i r2 = _mm512_shuffle_i32x4(v0, v2, 0xdd); + __m512i r3 = _mm512_shuffle_i32x4(v1, v3, 0xdd); + + // Store output + if (nrem < 16) { + __mmask64 mask = (1ULL << (nrem * 4)) - 1; + _mm512_mask_storeu_epi8(dst, mask, r0); + } else if (nrem == 16) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + } else if (nrem < 32) { + int n_bytes1 = 64; + int n_bytes2 = (nrem * 4) - n_bytes1; + __mmask64 mask = (1ULL << n_bytes2) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64), mask, r1); + } else if (nrem == 32) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + } else if (nrem < 48) { + int n_bytes1 = 64 * 2; + int n_bytes2 = (nrem * 4) - n_bytes1; + __mmask64 mask = (1ULL << n_bytes2) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 2), mask, r2); + } else if (nrem == 48) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2); + } else if (nrem < 64) { + int n_bytes1 = 64 * 3; + int n_bytes2 = (nrem * 4) - n_bytes1; + __mmask64 mask = (1ULL << n_bytes2) - 1; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2); + _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 3), mask, r3); + } else { + // normal case, nrem == 64 + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 3), r3); + } +#else + TORCH_CHECK( + false, + "transpose_pad_4x64_block is only supported when AVX-512 is supported") +#endif +} + +// Reorder [K, N] → [K/4, N, 4] (VNNI4-style layout for bit8) +template > +static inline void pack_vnni4( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int64_t K, + int64_t N) { +#if defined(CPU_CAPABILITY_AVX512) + int64_t bk = 0; + int64_t _K = K / 4 * 4; + int64_t _N = N / 64 * 64; + for (; bk < _K; bk += 4) { + int64_t bn = 0; + for (; bn < _N; bn += 64) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, 4, nrem); + } + } + + // Handle leftover K rows (< 4) + if (K % 4 != 0) { + int krem = K - bk; + int64_t bn = 0; + for (; bn < _N; bn += 64) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem); + } + int64_t nrem = N - bn; + if (nrem > 0) { + transpose_pad_4x64_block( + src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem, nrem); + } + } +#else + TORCH_CHECK(false, "pack_vnni4 is only supported when AVX-512 is supported") +#endif +} + +// This is a helper function for transpose_pack_vnni4 +// Transform a [4, 16] block (with incontiguous output) +// Src: +// a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 a16 +// b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 b16 +// c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 c16 +// d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 d16 +// Dst: +// a1 a2 a3 a4 b1 b2 b3 b4 c1 c2 c3 c4 d1 d2 d3 d4 +// a5 a6 a7 a8 b5 b6 b7 b8 c5 c6 c7 c8 d5 d6 d7 d8 +// a9 a10 a11 a12 b9 b10 b11 b12 c9 c10 c11 c12 d9 d10 d11 d12 +// a13 a14 a15 a16 b13 b14 b15 b16 c13 c14 c15 c16 d13 d14 d15 d16 +template > +static inline void transpose_vnni4_pad_4x16_block( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int64_t ld_dst, + int krem = 4) { +#if defined(CPU_CAPABILITY_AVX512) + __m128i r[4]; + for (int i = 0; i < krem; ++i) { + r[i] = _mm_loadu_si128(reinterpret_cast(src + i * ld_src)); + } + for (int i = krem; i < 4; ++i) { + r[i] = _mm_setzero_si128(); + } + + // Transpose 4x16 bytes using unpack and shuffle + __m128i t0 = _mm_unpacklo_epi32(r[0], r[1]); + __m128i t1 = _mm_unpackhi_epi32(r[0], r[1]); + __m128i t2 = _mm_unpacklo_epi32(r[2], r[3]); + __m128i t3 = _mm_unpackhi_epi32(r[2], r[3]); + + __m128i r0 = _mm_unpacklo_epi64(t0, t2); + __m128i r1 = _mm_unpackhi_epi64(t0, t2); + __m128i r2 = _mm_unpacklo_epi64(t1, t3); + __m128i r3 = _mm_unpackhi_epi64(t1, t3); + + // Store output + if (krem == 4) { + // normal case + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), r0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst), r1); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 2), r2); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 3), r3); + } else { + // masked case + __mmask16 mask = (1ULL << (krem * 4)) - 1; + _mm_mask_storeu_epi8(dst, mask, r0); + _mm_mask_storeu_epi8(reinterpret_cast<__m128i*>(dst + ld_dst), mask, r1); + _mm_mask_storeu_epi8( + reinterpret_cast<__m128i*>(dst + ld_dst * 2), mask, r2); + _mm_mask_storeu_epi8( + reinterpret_cast<__m128i*>(dst + ld_dst * 3), mask, r3); + } +#else + TORCH_CHECK( + false, + "transpose_vnni4_pad_4x16_block is only supported when AVX-512 is supported") +#endif +} + +// Do the transpose packing fusion with VNNI4 +// Reorder [K, N] → [N/4, K, 4] (VNNI4-style layout for bit8) +template > +static inline void transpose_pack_vnni4( + const scalar_t* src, + scalar_t* dst, + int64_t ld_src, + int64_t K, + int64_t N) { +#if defined(CPU_CAPABILITY_AVX512) + TORCH_CHECK( + N % 16 == 0, "N needs to be multiple of 16 for transpose_pack_vnni4"); + int64_t bk = 0; + int64_t _K = K / 4 * 4; + for (; bk < _K; bk += 4) { + int64_t bn = 0; + for (; bn < N; bn += 16) { + transpose_vnni4_pad_4x16_block( + src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4); + } + } + + // Handle leftover K rows (< 4) + if (K % 4 != 0) { + int krem = K - bk; + int64_t bn = 0; + for (; bn < N; bn += 16) { + transpose_vnni4_pad_4x16_block( + src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4, krem); + } + } +#else + TORCH_CHECK( + false, "transpose_pack_vnni4 is only supported when AVX-512 is supported") +#endif +} + +} // namespace CPU_CAPABILITY +} // namespace at::vec + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1a716caaaabe88e4ef2fbcb416a5badb181029c6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h @@ -0,0 +1,102 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") +#include +#include +#include +C10_DIAGNOSTIC_POP() + + +namespace ao::sparse { + +struct TORCH_API PackedLinearWeight + : public LinearPackedParamsBase { + PackedLinearWeight(std::unique_ptr> w, + std::optional bias, + std::vector col_offsets, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme, + const int64_t out_features_block_size /* block sparsity size across output_features */, + const int64_t in_features_block_size /* block sparsity size across input_features */) + : LinearPackedParamsBase( + out_features_block_size, + in_features_block_size), + w(std::move(w)), + bias_(std::move(bias)), + col_offsets(std::move(col_offsets)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(q_scheme) {} + std::unique_ptr> w; + std::optional bias_; + std::vector col_offsets; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(const at::Tensor& input) override { + TORCH_INTERNAL_ASSERT( + false, + "Sparse quantized dynamic linear with fused relu is not yet " + "supported on qnnpack backend."); + return at::Tensor(); + } + at::Tensor apply_dynamic_relu(const at::Tensor& input) override { + TORCH_INTERNAL_ASSERT( + false, + "Sparse quantized dynamic linear with fused relu is not yet " + "supported on qnnpack backend."); + return at::Tensor(); + } + + LinearPackedSerializationType unpack() override; + + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + + std::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + const at::Tensor& weight, + const std::optional& bias, + const int64_t out_features_block_size, + const int64_t in_features_block_size); + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +}; + +} // namespace ao::sparse + +#endif // USE_FBGEMM + +namespace ao::sparse { +int register_linear_params(); +} // namespace ao::sparse + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h new file mode 100644 index 0000000000000000000000000000000000000000..191b1e160cb4ffdb7e24c75d45b1a82c3cc4a7b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h @@ -0,0 +1,78 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include + +namespace ao::sparse { + +// +using LinearPackedSerializationType = + std::tuple, std::vector>; + +#define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 2 + +using BCSRSerializationType = + std::tuple< + int64_t, // Serialization Version + std::optional, // Bias + int64_t, // Out Features (Row) Block Size + int64_t, // In Features (Column) Block Size + at::Tensor, // Weight Scales (single element vector if per-tensor) (float) + at::Tensor, // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t) + bool, // Quantization Scheme (true: per tensor, false: per channel) + at::Tensor, // Wrapper for Row Block Indices (int8_t, int16_t, or int32_t) + at::Tensor, // Wrapper for Column Block Indices (int8_t, int16_t, or int32_t) + at::Tensor, // Wrapper for Non-Zero Weight Values, each +128 (uint8_t) + int64_t, // Number of Output Channels + int64_t // Number of Input Channels + >; + +using BCSR = + std::tuple< + std::vector, // Non-Zero Weight Values + std::vector, // Compressed Row Block Indices + std::vector // Column Block Indices + >; + +struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { + public: + LinearPackedParamsBase( + const int64_t out_features_block_size, + const int64_t in_features_block_size) + : out_features_block_size_(out_features_block_size), + in_features_block_size_(in_features_block_size) {} + + virtual at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + + virtual at::Tensor apply_dynamic(const at::Tensor& input) = 0; + virtual at::Tensor apply_dynamic_relu(const at::Tensor& input) = 0; + + virtual LinearPackedSerializationType unpack() = 0; + + virtual BCSRSerializationType serialize() = 0; + + virtual std::optional bias() = 0; + + virtual void set_bias(const std::optional& bias) { + TORCH_CHECK(false, "set_bias is not implemented for this packed parameter type"); + } + + protected: + const int64_t out_features_block_size_, in_features_block_size_; +}; + +} // namespace ao::sparse + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..cdd7f91cb49c918f4a1a64d174f05dd16a84e144 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h @@ -0,0 +1,95 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#ifdef USE_PYTORCH_QNNPACK +// TODO: Refacto QnnpackUtils.h so as to separate code +// needed for quantized op from the generic qnnpack specific +// quantization utilities. +#include +#include +#include + +namespace ao::sparse { + +struct TORCH_API PackedLinearWeightQnnp : public LinearPackedParamsBase { + PackedLinearWeightQnnp(const at::Tensor& weight, const std::optional& bias, const int64_t out_features_block_size /* block sparsity size across output_features */, const int64_t in_features_block_size /* block sparsity size across input_features */); + explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized); + std::optional orig_bias_; + // Separate copy of bias exist so that we can fill in zeros when + // optional bias does not exist. This is to compy with qnnpack operator that + // expects bias to be present. + // In case bias is present bias_ is just a reference to orig_bias_ + at::Tensor bias_; + c10::QScheme q_scheme_; + double input_scale_{}; + std::unique_ptr bcsr_matrix_; + at::Tensor w_scales_; + std::vector w_zero_points_; + std::vector requantization_scales_; + std::unique_ptr + sparse_linear_op_{nullptr}; + int64_t output_channels_; + int64_t input_channels_; + // Deserialized Tensors are stored to maintain the lifetime of underlying + // BCSR data. + // These are left empty if PackedLinearWeightQnnp is created via prepacking + // rather than deserializing. + at::Tensor deserialized_bcsr_row_block_indices_; + at::Tensor deserialized_bcsr_col_block_indices_; + at::Tensor deserialized_bcsr_weight_values_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override { + TORCH_CHECK( + false, "Static quantized sparse linear unimplemented on QNNPACK"); + } + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override { + TORCH_CHECK( + false, "Static quantized sparse linear unimplemented on QNNPACK"); + } + + at::Tensor apply_dynamic(const at::Tensor& input) override; + at::Tensor apply_dynamic_relu(const at::Tensor& input) override; + + LinearPackedSerializationType unpack() override; + + BCSRSerializationType serialize() override; + + static c10::intrusive_ptr deserialize( + const BCSRSerializationType& serialized); + + std::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + const at::Tensor& weight, + const std::optional& bias, + const int64_t out_features_block_size, + const int64_t in_features_block_size); + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + template + at::Tensor apply_dynamic_impl(const at::Tensor& input); +}; + +} // namespace ao::sparse + +#endif // USE_PYTORCH_QNNPACK + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Activation.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Activation.h new file mode 100644 index 0000000000000000000000000000000000000000..39742b47768ec19ce12c2b74833b65abff0bab69 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Activation.h @@ -0,0 +1,21 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +template +struct ELUParams { + T alpha; + T scale; + T input_scale; +}; + +template +struct ELUBackwardParams { + T alpha; + T scale; + T input_scale; + bool is_result; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/EmbeddingBag.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/EmbeddingBag.h new file mode 100644 index 0000000000000000000000000000000000000000..e1f50d0950ee478fda8d349164e35a8932ed7264 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/EmbeddingBag.h @@ -0,0 +1,53 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#ifdef __METAL__ +enum class EmbeddingBagMode { SUM = 0, MEAN, MAX }; +#else +#include +using at::native::EmbeddingBagMode; +#endif + +template +struct EmbeddingBagParams { + ::c10::metal::array weight_strides; + ::c10::metal::array output_strides; + ::c10::metal::array max_indices_strides; + + bool use_per_sample_weights; + idx_type_t per_sample_weights_stride; + + idx_type_t num_indices; + idx_type_t num_bags; + idx_type_t feature_size; + idx_type_t num_weights; + + EmbeddingBagMode mode; + int64_t padding_idx; +}; + +template +struct EmbeddingBagBackwardParams { + ::c10::metal::array weight_grad_strides; + ::c10::metal::array output_grad_strides; + ::c10::metal::array max_indices_strides; + bool use_per_sample_weights; + idx_type_t per_sample_weights_stride; + idx_type_t feature_size; + EmbeddingBagMode mode; + int64_t padding_idx; +}; + +template +struct EmbeddingBagPerSampleWeightsBackwardParams { + ::c10::metal::array output_grad_strides; + ::c10::metal::array weight_strides; + idx_type_t per_sample_weights_grad_stride; + idx_type_t feature_size; + int64_t padding_idx; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/GridSampler.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/GridSampler.h new file mode 100644 index 0000000000000000000000000000000000000000..2d4c3f2beacf9097e0721729e08829cf638d2a05 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/GridSampler.h @@ -0,0 +1,30 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#ifdef __METAL__ +enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic }; +enum class GridSamplerPadding { Zeros, Border, Reflection }; +#else +#include +using at::native::GridSamplerInterpolation; +using at::native::GridSamplerPadding; +#endif + +template +struct GridSamplerParams { + int32_t sampler_dims; + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array grid_sizes; + ::c10::metal::array grid_strides; + GridSamplerInterpolation interpolation_mode; + GridSamplerPadding padding_mode; + bool align_corners; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/LinearAlgebra.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/LinearAlgebra.h new file mode 100644 index 0000000000000000000000000000000000000000..238252b54a734dddfc8c34637269accaa98e7ccc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/LinearAlgebra.h @@ -0,0 +1,27 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +template +struct OrgqrParams { + int32_t num_batch_dims; + + uint32_t m; + uint32_t n; + uint32_t k; + + ::c10::metal::array A_strides; + ::c10::metal::array tau_strides; + ::c10::metal::array H_strides; + ::c10::metal::array H_sizes; +}; + +struct UnpackPivotsParams { + uint32_t perm_batch_stride; + uint32_t pivots_batch_stride; + uint32_t dim_size; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Pooling.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..192805eb79413356fe3a5f89ce0a8fe6fcc06498 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Pooling.h @@ -0,0 +1,66 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +// N is the maximum allowed number of dimensions in the input and outputs. The +// maximum allowed pooling dimensions is N-2, because the input may have up to 2 +// leading dimensions that are not pooled. To support up to 3-D pooling, N=5 is +// the default. +template +struct PoolingParams { + int32_t dims; + int32_t pooling_dims; + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array indices_sizes; + ::c10::metal::array indices_strides; + ::c10::metal::array kernel_size; + ::c10::metal::array stride; + ::c10::metal::array padding; + ::c10::metal::array dilation; + bool return_indices; +}; + +template +struct AvgPoolingParams { + int32_t dims; + int32_t pooling_dims; + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array kernel_size; + ::c10::metal::array stride; + ::c10::metal::array padding; + bool count_include_pad; + bool has_divisor_override; + int32_t divisor_override; +}; + +template +struct PoolingBackwardParams { + int32_t dims; + int32_t pooling_dims; + ::c10::metal::array grad_input_sizes; + ::c10::metal::array grad_input_strides; + ::c10::metal::array grad_output_sizes; + ::c10::metal::array grad_output_strides; + ::c10::metal::array indices_strides; +}; + +template +struct MaxUnpoolingParams { + int32_t dims; + int32_t pooling_dims; + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array indices_strides; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Shape.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Shape.h new file mode 100644 index 0000000000000000000000000000000000000000..e8370f69cd48c67beb1dd044a2392c1766f3441d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/Shape.h @@ -0,0 +1,23 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +template +struct CatSharedParams { + int32_t ndim; + int32_t cat_dim; + ::c10::metal::array output_strides; + ::c10::metal::array output_sizes; +}; + +template +struct CatInputParams { + idx_type_t cat_dim_offset; + idx_type_t input_element_offset; + ::c10::metal::array input_strides; + ::c10::metal::array input_sizes; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/TensorCompare.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/TensorCompare.h new file mode 100644 index 0000000000000000000000000000000000000000..66970743448e48dd55292912f4137c9eafe251a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/TensorCompare.h @@ -0,0 +1,12 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +template +struct ClampScalarParams { + T min; + T max; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h new file mode 100644 index 0000000000000000000000000000000000000000..14bbe274b5139017f303f705aa6d2f29ca810826 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/kernels/UpSample.h @@ -0,0 +1,17 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +template +struct UpsampleParams { + ::c10::metal::array input_strides; + ::c10::metal::array input_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array output_sizes; + ::c10::metal::array scales; + bool align_corners; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..bfedae8bb7dbbd5ebb1772612344e9a7de85af32 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/BinaryKernel.h @@ -0,0 +1,15 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +namespace at::native::mps { +void binary_op_kernel( + const std::string func_name, + const Tensor& input, + const Tensor& other, + const Tensor& output, + const std::optional alpha = std::nullopt); +} // namespace at::native::mps + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..e7d8ab12f3b09e6dddfc87447729916bf1a2c065 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h @@ -0,0 +1,43 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::native::mps { + +void _fused_adam_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native::mps + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..457cbbc46c85b4d204d5715672fe054db5229288 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h @@ -0,0 +1,40 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::native::mps { + +void _fused_adam_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); +} // namespace at::native::mps + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..80d66bc90a748d6fe5c4ed8ad15a7e283e47c294 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h @@ -0,0 +1,42 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::native::mps { + +void _fused_adamw_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_amsgrad_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList max_exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); +} // namespace at::native::mps + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..94bc73bb1d5991653f57b16d52e7c48a55fa6904 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h @@ -0,0 +1,41 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::native::mps { + +void _fused_adamw_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_mps_impl_( + TensorList params, + TensorList grads, + TensorList exp_avgs, + TensorList exp_avg_sqs, + TensorList state_steps, + const Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native::mps + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h new file mode 100644 index 0000000000000000000000000000000000000000..c156a5789ede7b044a3a10c9b7ca9efec699adc3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/mps/operations/MultiTensorApply.h @@ -0,0 +1,367 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +static_assert(sizeof(bool) == 1); + +namespace at::native::mps { + +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kmaxThreadGroups = 32; +static constexpr int64_t kmaxTensors = 32; + +struct MetadataArguments { // the size of this struct must be less than 4 kilobytes + uint64_t numels[kmaxTensors]; + uint64_t threadgroup_to_tensor[kmaxThreadGroups]; + uint64_t threadgroup_to_chunk[kmaxThreadGroups]; +}; + +struct FusedAdamEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize) const { + mtl_setArgs( + computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize) const { + mtl_setArgs( + computeEncoder, tensorArgumentBuffer, metadata_arguments, lr, beta1, beta2, weight_decay, eps, maximize); + } +}; + +template +struct FusedSgdEncodingFunctor {}; + +template <> +struct FusedSgdEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double momentum, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step) const { + mtl_setArgs(computeEncoder, + tensorArgumentBuffer, + metadata_arguments, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double momentum, + const at::Tensor& lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step) const { + mtl_setArgs(computeEncoder, + tensorArgumentBuffer, + metadata_arguments, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step); + } +}; + +template <> +struct FusedSgdEncodingFunctor { + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const double lr, + const bool maximize) const { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); + } + + void operator()(id& computeEncoder, + id& tensorArgumentBuffer, + const MetadataArguments& metadata_arguments, + const double weight_decay, + const at::Tensor& lr, + const bool maximize) const { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, weight_decay, lr, maximize); + } +}; + +std::pair, id> getFusedAdamCPLState(const std::string& fname); +template +static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name, + std::vector>& tensor_lists, + at::TensorList state_steps, + encoder_func_t encode, + ArgTypes... args) { + const auto num_tensors = tensor_lists[0].size(); + + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth"); + for (const auto& d : c10::irange(depth)) { + const auto scalar_type = tensor_lists[d][0].scalar_type(); + TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16, + "Only float, bfloat and half are supported"); + } + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + // Remove comment for debugging + /* + mpsStream->addCompletedHandler(^(id cb) { + [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) { + NSLog(@"MPSStream: %@", log); + } + ]; + }); + */ + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [fusedOptimizerPSO, fusedOptimizerFunc] = getFusedAdamCPLState(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]}); + + [computeEncoder setComputePipelineState:fusedOptimizerPSO]; + + // BufferIndex is the index in the kernel function + auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease]; + id tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int64_t tensor_loc = 0; + int64_t threadgroup_loc = 0; + MetadataArguments metadata_arguments; + + for (const auto tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + + for (const auto& d : c10::irange(depth)) { + mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) + usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + if (!state_steps.empty()) { + mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + } + metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel(); + + tensor_loc++; + + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + + for (const auto& chunk : c10::irange(chunks)) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + + threadgroup_loc++; + + const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1; + // Reach the maximum threadgroups per dispatch + const auto blocks_full = threadgroup_loc == kmaxThreadGroups; + + if (tensor_full || blocks_full) { + encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // Reset + threadgroup_loc = 0; + if (chunk == chunks - 1) { + // last chunk + tensor_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + } else { + // reuse the current tensor since the current one isn't done. + metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1]; + + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength + options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + for (const auto& d : c10::irange(depth)) { + mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) + usage:MTLResourceUsageWrite | MTLResourceUsageRead]; + } + if (!state_steps.empty()) { + mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors); + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + } + tensor_loc = 1; + } + } + } + } + + if (threadgroup_loc != 0) { + encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + + getMPSProfiler().endProfileKernel(fusedOptimizerPSO); + } + }); +} + +std::pair, id> getAmpCPLState(const std::string& fname); +template +void multi_tensor_apply(const std::string& kernel_name, + std::vector>& tensor_lists, + ArgTypes... args) { + const auto num_tensors = tensor_lists[0].size(); + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists must match depth."); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [pipeline, function] = getAmpCPLState(kernel_name); + [computeEncoder setComputePipelineState:pipeline]; + + id argumentEncoder = [function newArgumentEncoderWithBufferIndex:0]; + auto tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int tensor_loc = 0; + int threadgroup_loc = 0; + MetadataArguments metadata_arguments; + std::memset(&metadata_arguments, 0, sizeof(metadata_arguments)); + + for (size_t t = 0; t < num_tensors; t++) { + if (tensor_lists[0][t].numel() == 0) + continue; + + // bind each tensor in this list to the correct slots across depths + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + + // save number of elements for this tensor + metadata_arguments.numels[tensor_loc] = tensor_lists[0][t].numel(); + int currentTensorIndex = tensor_loc; + tensor_loc++; + + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + ((numel % kChunkSize) ? 1 : 0); + + // process tensor in chunks based on max chunk size + for (uint chunk = 0; chunk < chunks; chunk++) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = currentTensorIndex; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + threadgroup_loc++; + + // dispatch when we've filled the threadgroup array or finished the chunks + const bool dispatch_now = (threadgroup_loc == kmaxThreadGroups) || (chunk == chunks - 1); + if (dispatch_now) { + // check for a partial dispatch (i.e. more chunks remain for the current tensor) + bool partial = (chunk != chunks - 1); + uint carried_numels = 0; + if (partial) { + carried_numels = metadata_arguments.numels[currentTensorIndex]; + } + + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, (uint32_t)64), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // prepare for the next batch: reset threadgroup count and create a new buffer + threadgroup_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + if (partial) { + // for a partial dispatch, rebind the partially processed tensor to slot 0 + // so that its metadata is in the correct location + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + 0); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + metadata_arguments.numels[0] = carried_numels; + // the currently processed tensor now lives at index 0 + currentTensorIndex = 0; + tensor_loc = 1; + } else { + tensor_loc = 0; + } + } + } + } + + if (threadgroup_loc != 0) { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, static_cast(64)), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + } + }); +} + +} // namespace at::native::mps + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..eaf7836eb027d541a23224ecc68d7cb39e9fb33b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,636 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + File copied from + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + ///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + ///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + using SourceAccessType = Array< + typename OutputTileSourceIterator::Element, + OutputTileSourceIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + constexpr int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000000000000000000000000000000000..44a8adde27ca17333f731f9de0316f9bcc6aa3dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,243 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to store tensors + typename ElementSource_, //< Data type for source (usually matches + //`ElementOutput`) + int Count, ///< Number of elements computed per operation. + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const { + assert(!isFirst); + + // Convert source to internal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row]; + ElementCompute alpha = isLast ? (1 / denom) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to internal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + // Row sums for full masked out rows are 0, we set them to 1 + // In order to avoid NaNs in the output and instead sem them to 0. + ElementCompute denom = s_prime_[row] == 0 ? 1 : s_prime_[row]; + ElementCompute alpha = isLast ? (1 / denom) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template < + typename EO, + typename ES, + int Count, + typename EA, + typename EC, + bool F, + bool L, + typename FAB, + FloatRoundStyle R> +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000000000000000000000000000000000..2f5bf2957dad1f41639111a093b52b841842045d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,180 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = expf(input[i]); + } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()( + Array const& input) const { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = + reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = h2exp(input_ptr[i]); + } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template < + typename ElementOutput_, // output + typename ElementLSE_, // accumulator from LSE + typename ElementAccumulator_, // accumulator from matmul + typename ElementCompute_, // intermediate compute (and exp calculation) + int ElementsPerAccess> +class ApplyLogSumExp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int constexpr kElementsPerAccess = ElementsPerAccess; + static int constexpr kCount = kElementsPerAccess; + static constexpr ScaleType::Kind kScale = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + + public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const { + FragmentCompute frag_AB = NumericArrayConverter< + ElementCompute, + ElementAccumulator, + kElementsPerAccess>()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()( + bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter< + ElementOutput, + ElementCompute, + kElementsPerAccess>()(frag_compute); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..e0892db259b2f8aadec4ee4bb5615b6224ff5907 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h @@ -0,0 +1,148 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Instantiates the right WarpIterator to read from shared memory + The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading + data dumped with `B2bGemm::accumToSmem`. +*/ + +#pragma once + +#include +#include +#include + +#include + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, kInstrK>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::MatrixShape>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..ff985595ca81354f02bf3c9681ebf99887f3e26d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,757 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue iterator that supports prefetching + + Mostly copied from +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false> +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( + ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert( + ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert( + ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert( + ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert( + sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch( + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < + extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) + row_add_P = 0; + if (output_Q > convolution_Q - 2) + row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { + mask_ = mask; + } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch< + typename IT::ThreadMap, + typename IT::Element>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..5c8e84bbf52ae598f607309077ed369cea5e1327 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h @@ -0,0 +1,79 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..1b1f9f6917b428f129a668591804a73b9f22592c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,2120 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather = false> +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base( + layout.stride(0), + MakePredicatedTileAccessIteratorDesc< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap>()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / + 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * sizeof_bits::value / 8; + + return reinterpret_cast( + pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + layout::PitchLinear, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. + LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..0f1f5a2f63f12a0537ced07872ea81121315f9d4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,2125 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Access out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false> +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..916d436d7e86a8af36167a780ff726ef5b410ac8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + typename InstructionShape, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp:: + WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp:: + WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..000ad3f97ac2d3f9918cb37dd3246ee5482dded8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,289 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + typename InstructionShape_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + static_assert( + kOperand == Operand::kA, + "No support for OperandB at the moment"); + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16"); + static_assert( + InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16, + "Only supports 16x8x8 / 16x8x16"); + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + // Number of 32bits tiles to load per `ldmatrix` + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8"); + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + // See also: + // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 + // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4) + // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4) + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kTilesPerInstruction == 4, + "can't use ldmatrix.x4"); + int access_m_idx = ldsm_vec_num % kTilesPerInstruction; + int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner; + int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner); + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } else { + // XXX: This is not tested or used + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < + (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4; + ++access_m_idx) { + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord( + access_m_idx * 16, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[access_m_idx], ref_.data() + ref_.offset(offset)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h new file mode 100644 index 0000000000000000000000000000000000000000..4bcfdfae1ad9ce457ec724fce972654d6091eb00 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h @@ -0,0 +1,919 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// This file is auto-generated. See "generate_kernels.py" +#pragma once +#include +using namespace PyTorchMemEffAttention; +// ======== f16 / sm70 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm70(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70); +} + +// ======== bf16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_bf16_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_sm80); + if (cc == 86 || cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k96_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80); +} + +// ======== f16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm80); + if (cc == 86 || cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k96_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80); +} + +// ======== f16 / sm50 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm50(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50); +} + +// ======== f32 / sm50 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM) +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +#else +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +#endif +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm50(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm50); +#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM) + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50); +#else + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50); +#endif + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50); +} + +// ======== f32 / sm70 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm70(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70); +} + +// ======== f16 / sm75 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f16_sm75(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75); +} + +// ======== f32 / sm75 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm75(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75); +} + +// ======== f32 / sm80 ======== +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p); + +template void dispatch_cutlassB_f32_sm80(T cb, int cc) { + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80); + cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80); +} + + +template +void dispatch_cutlassB(T cb, int cc = 0) { + + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassB_f16_sm70(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_bf16_sm80(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_f16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassB_f16_sm50(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassB_f32_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassB_f32_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassB_f16_sm75(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassB_f32_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassB_f32_sm80(cb, cc); + } +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h new file mode 100644 index 0000000000000000000000000000000000000000..68a4dcad6bb7d015435df371b011c43469ddfb01 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h @@ -0,0 +1,318 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +// This file is auto-generated. See "generate_kernels.py" +#pragma once +#include +using namespace PyTorchMemEffAttention; +// ======== bf16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_bf16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_bf16_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_32x128_gmem_sm80); +} + +// ======== f16 / sm50 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm50(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm50); +} + +// ======== f16 / sm70 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm70(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm70); +} + +// ======== f16 / sm75 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm75(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f16_notaligned_32x128_gmem_sm75); +} + +// ======== f16 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f16_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm80); +} + +// ======== f32 / sm50 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm50(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm50(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm50); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm50); +} + +// ======== f32 / sm70 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm70(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm70(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm70); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm70); +} + +// ======== f32 / sm75 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_64x64_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_rf_sm75(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm75(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_64x64_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_rf_sm75); + cb(AttentionKernel(), fmha_cutlassF_f32_notaligned_32x128_gmem_sm75); +} + +// ======== f32 / sm80 ======== +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x64_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_64x128_rf_sm80(typename AttentionKernel::Params p); +__global__ void __launch_bounds__( + AttentionKernel::kNumThreads, + AttentionKernel::kMinBlocksPerSm) +fmha_cutlassF_f32_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); + +template void dispatch_cutlassF_f32_sm80(T cb, int cc) { + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x128_rf_sm80); + cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm80); +} + + +template +void dispatch_cutlassF(T cb, int cc = 0) { + + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_bf16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassF_f16_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassF_f16_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassF_f16_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_f16_sm80(cb, cc); + } + if (std::is_same_v && 50 <= cc && cc < 70) { + dispatch_cutlassF_f32_sm50(cb, cc); + } + if (std::is_same_v && 70 <= cc && cc < 75) { + dispatch_cutlassF_f32_sm70(cb, cc); + } + if (std::is_same_v && 75 <= cc && cc < 80) { + dispatch_cutlassF_f32_sm75(cb, cc); + } + if (std::is_same_v && 80 <= cc && cc <= 120) { + dispatch_cutlassF_f32_sm80(cb, cc); + } +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..c2ea60fe3b642bd55892e4998b03fcc1f8492200 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h @@ -0,0 +1,71 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h new file mode 100644 index 0000000000000000000000000000000000000000..960eb5f101bc4ef5fe4c67222f91799cd53fbbe5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h @@ -0,0 +1,72 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#include + +#if defined(USE_ROCM_CK_SDPA) +namespace pytorch_flash { + +std::tuple< + at::Tensor, // output + at::Tensor, // q + at::Tensor, // k + at::Tensor, // v + at::Tensor, // lse + at::Tensor, // seed + at::Tensor, // offset + at::Tensor> // dropout randval +mem_eff_forward_ck( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + float p_dropout, + bool return_dropout_randval, + std::optional is_causal, + std::optional scale, + const std::optional& attn_bias_, + std::optional& out_, + const std::optional& cu_seqlens_q, + const std::optional& cu_seqlens_k, + const std::optional& seqstart_q, + const std::optional& seqstart_k, + std::optional gen_, + std::optional& seqused_k_ +); + +std::tuple< + at::Tensor, // dQ + at::Tensor, // dK + at::Tensor, // dV + at::Tensor> // dBias +mem_eff_backward_ck( + const at::Tensor &dout, + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + const at::Tensor &out, + const at::Tensor &softmax_lse, + const at::Tensor &dq_, + const at::Tensor &dk_, + const at::Tensor &dv_, + std::optional &attn_bias, + bool bias_requires_grad, + std::optional &grad_bias, + std::optional &cu_seqlens_q, + std::optional &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float p_dropout, + float scale, + bool is_causal, + bool deterministic, + bool zero_tensors, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +} // namespace pytorch_flash +#endif // USE_ROCM_CK_SDPA + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 0000000000000000000000000000000000000000..a8248d01a14e583ac8889d555d047877a7e7821a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,568 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#include +#include +#include + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace pytorch_flash { + +// AOTriton Implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_aot( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_aot( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& block_table_, + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple mha_bwd_aot( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +std::tuple mha_varlen_bwd_aot( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +#if defined(USE_ROCM_CK_SDPA) +// CK implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_ck( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); // batch_size x nheads x seqlen_q x seqlen_k + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_ck( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_, + const std::optional& attn_bias_); + +std::tuple mha_bwd_ck( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + attn_bias_, // batch_size x num_heads x seqlen_q x seqlen_k + bool bias_requires_grad, + std::optional& grad_bias, + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple mha_varlen_bwd_ck( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& attn_bias_, // num_heads or b x num_heads + bool bias_requires_grad, + std::optional& grad_bias, + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#endif + +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& + block_table_, // Not used on ROCm. Keeping for parity with CUDA + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional dummy_attn_bias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + return mha_varlen_fwd_ck( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_varlen_fwd_aot( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} + +inline std::tuple mha_bwd( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { + +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = mha_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } +#endif + return mha_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +inline std::tuple mha_varlen_bwd( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + std::optional non_null_dbias = std::nullopt; + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + auto[dQuery, + dKey, + dValue, + dSoftmax, + dBias] = mha_varlen_bwd_ck( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + false, // bias_requires_grad + non_null_dbias, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + non_null_window_left, + non_null_window_right, + deterministic, + philox_seed, + philox_offset); + // for FA return [dQ, dV, dK, dSoftmax] + return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); + } +#endif + return mha_varlen_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +} // namespace pytorch_flash + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/xpu/detail/XPUHooks.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/xpu/detail/XPUHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..c1771b96ff4094c38527df0e3e08e3796637654a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/xpu/detail/XPUHooks.h @@ -0,0 +1,38 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::xpu::detail { + +// The real implementation of XPUHooksInterface +struct XPUHooks : public at::XPUHooksInterface { + XPUHooks(at::XPUHooksArgs) {} + void init() const override; + bool hasXPU() const override; + std::string showConfig() const override; + int32_t getGlobalIdxFromDevice(const at::Device& device) const override; + const Generator& getDefaultGenerator( + DeviceIndex device_index = -1) const override; + Generator getNewGenerator(DeviceIndex device_index = -1) const override; + Device getDeviceFromPtr(void* data) const override; + c10::DeviceIndex getNumGPUs() const override; + DeviceIndex current_device() const override; + void deviceSynchronize(DeviceIndex device_index) const override; + Allocator* getPinnedMemoryAllocator() const override; + + bool isBuilt() const override { + return true; + } + bool isAvailable() const override; + bool isPinnedPtr(const void* data) const override; + bool hasPrimaryContext(DeviceIndex device_index) const override; + DeviceIndex deviceCount() const override; + DeviceIndex getCurrentDevice() const override; +}; + +} // namespace at::xpu::detail + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Allocator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..b66f075ec73fb77290e317e911c66e4497ca1469 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Allocator.h @@ -0,0 +1,455 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +using CaptureId_t = unsigned long long; +// first is set if the instance is created by CUDAGraph::capture_begin. +// second is set if the instance is created by at::cuda::graph_pool_handle. +using MempoolId_t = std::pair; + +struct MempoolIdHash { + std::size_t operator()(const MempoolId_t& mempool_id) const noexcept { + return mempool_id.first != 0 ? mempool_id.first : mempool_id.second; + } +}; + +// A DataPtr is a unique pointer (with an attached deleter and some +// context for the deleter) to some memory, which also records what +// device is for its data. +// +// nullptr DataPtrs can still have a nontrivial device; this allows +// us to treat zero-size allocations uniformly with non-zero allocations. +// +class C10_API DataPtr { + private: + c10::detail::UniqueVoidPtr ptr_; + Device device_; + + public: + // Choice of CPU here is arbitrary; if there's an "undefined" device + // we could use that too + DataPtr() : device_(DeviceType::CPU) {} + DataPtr(void* data, Device device) : ptr_(data), device_(device) {} + DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device) + : ptr_(data, ctx, ctx_deleter), device_(device) {} + void* operator->() const { + return ptr_.get(); + } + C10_ALWAYS_INLINE bool /* success */ unsafe_reset_data_and_ctx( + void* new_data_and_ctx) { + return ptr_.unsafe_reset_data_and_ctx(new_data_and_ctx); + } + void clear() { + ptr_.clear(); + } + void* get() const { + return ptr_.get(); + } + void* mutable_get() { + return ptr_.get(); + } + void* get_context() const { + return ptr_.get_context(); + } + void* release_context() { + return ptr_.release_context(); + } + std::unique_ptr&& move_context() { + return ptr_.move_context(); + } + operator bool() const { + return static_cast(ptr_); + } + template + T* cast_context(DeleterFnPtr expected_deleter) const { + return ptr_.cast_context(expected_deleter); + } + DeleterFnPtr get_deleter() const { + return ptr_.get_deleter(); + } + /** + * Compare the deleter in a DataPtr to expected_deleter. + * If it matches, replace the deleter with new_deleter + * and return true; otherwise, does nothing and returns + * false. + * + * In general, it is not safe to unconditionally set the + * deleter on a DataPtr, because you don't know what + * the deleter is, and thus will have a hard time properly + * disposing of the deleter without storing the original + * deleter (this is difficult to do, because DeleterFnPtr + * is not a closure, and because the context on DataPtr is + * only a single word, you generally don't have enough + * space to store both the original deleter and its context). + * However, in some cases, you know /exactly/ what the deleter + * is, and you have a new deleter that manually wraps + * the old one. In this case, you can safely swap the deleter + * after asserting that the deleters line up. + * + * What are the requirements on new_deleter? It must still + * properly dispose of the void* pointer passed in as its argument, + * where void* is whatever the context of the original deleter + * is. So in general, you expect the new deleter to look something + * like this: + * + * [](void* ptr) { + * some_new_stuff(ptr); + * get_orig_allocator()->raw_deleter(ptr); + * } + * + * Note that it won't work to close over the original + * allocator; you don't have enough space to do that! Also, + * it's unsafe to assume that the passed in pointer in + * question is the memory pointer in question; it might not + * be; be sure to read the source code of the Allocator + * in question to confirm this. + */ + [[nodiscard]] bool compare_exchange_deleter( + DeleterFnPtr expected_deleter, + DeleterFnPtr new_deleter) { + return ptr_.compare_exchange_deleter(expected_deleter, new_deleter); + } + Device device() const { + return device_; + } + // Unsafely mutates the device on a DataPtr. Under normal use, + // you should never actually need to call this function. + // We need this for the implementation of the hack detailed + // in Note [Masquerading as CUDA] + void unsafe_set_device(Device device) { + device_ = device; + } +}; + +// NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a +// CPU nullptr + +inline bool operator==(const DataPtr& dp, std::nullptr_t) noexcept { + return !dp; +} +inline bool operator==(std::nullptr_t, const DataPtr& dp) noexcept { + return !dp; +} +inline bool operator!=(const DataPtr& dp, std::nullptr_t) noexcept { + return dp; +} +inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept { + return dp; +} + +// Note [raw_allocate/raw_deallocate and Thrust] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Thrust's support for custom allocators requires us to write something +// like this: +// +// class ThrustAllocator { +// char* allocate(size_t); +// void deallocate(char*, size_t); +// }; +// +// This is not good for our unique_ptr based allocator interface, as +// there is no way to get to the context when we free. +// +// However, in some cases the context is exactly the same as +// the data pointer. In this case, we can support the "raw" +// allocate and deallocate interface. This is what +// raw_deleter signifies. By default, it returns a nullptr, which means that +// the raw interface is not implemented. Be sure to implement it whenever +// possible, or the raw interface will incorrectly reported as unsupported, +// when it is actually possible. + +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +struct C10_API Allocator { + virtual ~Allocator() = default; + + virtual DataPtr allocate(size_t n) = 0; + + // Clones an allocation that came from this allocator. + // + // To perform the copy, this function calls `copy_data`, which + // must be implemented by derived classes. + // + // Note that this explicitly ignores any context that may have been + // attached to the input data. + // + // Requires: input data was allocated by the same allocator. + DataPtr clone(const void* data, std::size_t n); + + // Checks if DataPtr has a simple context, not wrapped with any out of the + // ordinary contexts. + virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const; + + // If this returns a non nullptr, it means that allocate() + // is guaranteed to return a unique_ptr with this deleter attached; + // it means the rawAllocate and rawDeallocate APIs are safe to use. + // This function MUST always return the same BoundDeleter. + virtual DeleterFnPtr raw_deleter() const { + return nullptr; + } + void* raw_allocate(size_t n) { + auto dptr = allocate(n); + AT_ASSERT(dptr.get() == dptr.get_context()); + return dptr.release_context(); + } + void raw_deallocate(void* ptr) { + auto d = raw_deleter(); + AT_ASSERT(d); + d(ptr); + } + + // Copies data from one allocation to another. + // Pure virtual, so derived classes must define behavior. + // Derived class implementation can simply call `default_copy_data` + // to use `std::memcpy`. + // + // Requires: src and dest were allocated by this allocator + // Requires: src and dest both have length >= count + virtual void copy_data(void* dest, const void* src, std::size_t count) + const = 0; + + protected: + // Uses `std::memcpy` to copy data. + // Child classes can use this as `copy_data` when an alternative copy + // API is not needed. + void default_copy_data(void* dest, const void* src, std::size_t count) const; +}; + +// This context is used to generate DataPtr which have arbitrary +// std::function deleters associated with them. In some user facing +// functions, we give a (user-friendly) interface for constructing +// tensors from external data which take an arbitrary std::function +// deleter. Grep for InefficientStdFunctionContext to find these +// occurrences. +// +// This context is inefficient because we have to do a dynamic +// allocation InefficientStdFunctionContext, on top of the dynamic +// allocation which is implied by std::function itself. +struct C10_API InefficientStdFunctionContext { + void* ptr_{nullptr}; + std::function deleter_; + InefficientStdFunctionContext(void* ptr, std::function deleter) + : ptr_(ptr), deleter_(std::move(deleter)) {} + InefficientStdFunctionContext(const InefficientStdFunctionContext&) = delete; + InefficientStdFunctionContext(InefficientStdFunctionContext&& rhs) noexcept + : ptr_(std::exchange(rhs.ptr_, nullptr)), + deleter_(std::move(rhs.deleter_)) {} + InefficientStdFunctionContext& operator=( + const InefficientStdFunctionContext&) = delete; + // NOLINTNEXTLINE(*-noexcept-move-*) + InefficientStdFunctionContext& operator=( + InefficientStdFunctionContext&& rhs) { + this->~InefficientStdFunctionContext(); + ptr_ = std::exchange(rhs.ptr_, nullptr); + deleter_ = std::move(rhs.deleter_); + return *this; + } + ~InefficientStdFunctionContext() { + if (deleter_) { + deleter_(ptr_); + } + } + static DataPtr makeDataPtr( + void* ptr, + std::function deleter, + Device device); +}; + +/** Set the allocator for DeviceType `t`. The passed in allocator pointer is + * expected to have static lifetime; this function does NOT take ownership + * of the raw pointer. (The reason for this is to prevent existing pointers + * to an allocator of a particular device from being invalidated when + * SetAllocator is called.) + * + * Also note that this is not thread-safe, and we assume this function will + * only be called during initialization. + * + * The 'priority' flag is introduced when we want to overwrite the default + * allocator, since the allocators are set statically. The default priority + * is 0, which means the lowest. Only higher or equal priority can overwrite + * existing ones. + */ +C10_API void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0); +C10_API Allocator* GetAllocator(const DeviceType& t); + +template +struct AllocatorRegisterer { + explicit AllocatorRegisterer(Allocator* alloc) { + SetAllocator(t, alloc); + } +}; + +#define REGISTER_ALLOCATOR(t, f) \ + namespace { \ + static c10::AllocatorRegisterer g_allocator_d(f); \ + } + +// An interface for reporting thread local memory usage +// per device +struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase { + /** + * alloc_size corresponds to the size of the ptr. + * + * total_allocated corresponds to total allocated memory. + * + * total_reserved corresponds to total size of memory pool, both used and + * unused, if applicable. + */ + virtual void reportMemoryUsage( + void* ptr, + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device) = 0; + + virtual void reportOutOfMemory( + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device); + + virtual bool memoryProfilingEnabled() const = 0; +}; + +C10_API bool memoryProfilingEnabled(); +C10_API void reportMemoryUsageToProfiler( + void* ptr, + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device); + +C10_API void reportOutOfMemoryToProfiler( + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device); + +// used to hold traceback information in allocators +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +struct GatheredContext { + virtual ~GatheredContext() = default; +}; + +namespace CachingAllocator { +struct Stat { + void increase(size_t amount) { + current += static_cast(amount); + peak = std::max(current, peak); + allocated += static_cast(amount); + } + + void decrease(size_t amount) { + current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + current >= 0, + "Negative tracked stat in device allocator (likely logic error)."); + freed += static_cast(amount); + } + + void reset_accumulated() { + allocated = 0; + freed = 0; + } + + void reset_peak() { + peak = current; + } + + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +using StatArray = std::array(StatType::NUM_TYPES)>; +using StatTypes = std::array(StatType::NUM_TYPES)>; + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + +// Structure for keeping timing information +struct DurationStat { + void increase(int64_t amount) { + total += amount; + count += 1; + max = std::max(amount, max); + if (min == 0) { + min = amount; + } else { + min = std::min(amount, min); + } + } + + void reset_accumulated() { + total = 0; + count = 0; + } + + void reset_peak() { + min = 0; + max = 0; + } + + int64_t total = 0; + int64_t max = 0; + int64_t min = 0; + int64_t count = 0; +}; + +// Size pretty-printer +inline std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} + +} // namespace CachingAllocator +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/AllocatorConfig.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/AllocatorConfig.h new file mode 100644 index 0000000000000000000000000000000000000000..ab6a23d24d0884d72c869947857c02c22584b9c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/AllocatorConfig.h @@ -0,0 +1,390 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10::CachingAllocator { + +// "large" allocations may be packed in 20 MiB blocks +constexpr size_t kLargeBuffer = 20971520; +// "small" allocations are packed in 2 MiB blocks +constexpr size_t kSmallBuffer = 2097152; +// all sizes are rounded to at least 512 bytes +constexpr size_t kMinBlockSize = 512; +// largest "small" allocation is 1 MiB +constexpr size_t kSmallSize = 1048576; +// allocations between 1 and 10 MiB may use kLargeBuffer +constexpr size_t kMinLargeAlloc = 10485760; +// round up large allocations to 2 MiB +constexpr size_t kRoundLarge = 2097152; + +// A utility class for tokenizing allocator configuration strings into discrete +// parts. For example, the config string: +// "key1:val1,key2:[val2,val3]" +// is tokenized into: +// "key1", ":", "val1", ",", "key2", ":", "[", "val2", ",", "val3", "]", +// +// Tokens include keys, values, and special characters (':', ',', '[', ']'). +// Whitespace is ignored. +class ConfigTokenizer { + public: + explicit ConfigTokenizer(const std::string& env) { + std::string buffer; + for (char ch : env) { + if (ch == ',' || ch == ':' || ch == '[' || ch == ']') { + if (!buffer.empty()) { + config_.emplace_back(std::move(buffer)); + buffer.clear(); + } + config_.emplace_back(1, ch); + } else if (!std::isspace(static_cast(ch))) { + buffer += ch; + } + } + if (!buffer.empty()) { + config_.emplace_back(std::move(buffer)); + } + } + + const std::string& operator[](size_t i) const { + TORCH_INTERNAL_ASSERT( + i < config_.size(), "Index out of bounds in ConfigTokenizer"); + return config_[i]; + } + + size_t size() const { + return config_.size(); + } + + bool checkToken(size_t i, const std::string& token) const { + checkIndex(i); + return config_[i] == token; + } + + size_t toSizeT(size_t i) const { + checkIndex(i); + return std::stoull(config_[i]); + } + + double toDouble(size_t i) const { + checkIndex(i); + return std::stod(config_[i]); + } + + bool toBool(size_t i) const { + checkIndex(i); + const auto& token = config_[i]; + if (token == "True") { + return true; + } else if (token == "False") { + return false; + } else { + TORCH_CHECK_VALUE( + false, + "Expected 'True' or 'False' at index ", + i, + " in ConfigTokenizer but got '", + token, + "'"); + } + } + + // Skips the current token group and returns the index of the value token. + // Assumes the current index `i` points to a key name in a key-value pair. + size_t skipKey(size_t i) const { + // Expect a colon after the key + checkToken(++i, ":"); + + ++i; // Move to the value + checkIndex(i); + if (config_[i] != "[") { + // Value is a single token (not a list) -> return its index + return i; + } + + // Skip tokens inside the list until matching ']' + // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) + while (++i < config_.size() && config_[i] != "]") { + } + + TORCH_INTERNAL_ASSERT( + i < config_.size(), + "Expected closing bracket ']' in ConfigTokenizer but reached end of config"); + + return i; // Return the index of the closing ']' + } + + private: + void checkIndex(size_t i) const { + TORCH_INTERNAL_ASSERT( + i < config_.size(), "Index out of bounds in ConfigTokenizer"); + } + + std::vector config_; +}; + +/** + * Note [AcceleratorAllocatorConfig design] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * This class configures memory allocation for both device and host memory. A + * single `AcceleratorAllocatorConfig` instance is shared across all accelerator + * backends, such as CUDA and XPU, under the assumption that relevant + * environment variables apply uniformly to all accelerators. Device-specific + * configuration extensions are supported via hooks (see + * `registerDeviceConfigParserHook`). + * + * Recommended design: + * - Place common configurations in `AcceleratorAllocatorConfig`. + * - Extend backend-specific configurations in corresponding device-specific + * classes, such as `CUDAAllocatorConfig`, etc. + * + * Scope: + * - Configuration options must be environment-variable driven. + * + * Naming Convention: + * - Public API names in `AcceleratorAllocatorConfig` should be device-generic. + * - Members prefixed with `pinned_` are specific to the host/pinned allocator. + * - Environment variable names should be generic across backends. + * - Comma-separated key-value pairs in the format: `key:value`. Use square + * brackets `[]` for list values Example: `key1:123, key2:[val1,val2]` + * + * Environment Variables: + * - The primary environment variable for configuration is `PYTORCH_ALLOC_CONF`. + * - For backward compatibility, `PYTORCH_CUDA_ALLOC_CONF` is also supported + * with lower priority. + */ + +class C10_API AcceleratorAllocatorConfig { + public: + static AcceleratorAllocatorConfig& instance(); + + C10_DISABLE_COPY_AND_ASSIGN(AcceleratorAllocatorConfig); + AcceleratorAllocatorConfig(AcceleratorAllocatorConfig&&) = delete; + AcceleratorAllocatorConfig& operator=(AcceleratorAllocatorConfig&&) = delete; + ~AcceleratorAllocatorConfig() = default; + + /* Device allocator settings */ + + // Returns the maximum block size (in MB) that is allowed to be split. The + // default is unlimited (all blocks can be split). + static size_t max_split_size() { + return instance().max_split_size_; + } + + // Returns the maximum block size (in MB) that is allowed to be rounded up + // without requiring splitting when searching for a free block. The default is + // 20 MiB. + static size_t max_non_split_rounding_size() { + return instance().max_non_split_rounding_size_; + } + + // Return the number of divisions used when rounding up allocation sizes (in + // MB) to the nearest power-of-2 boundary. + static size_t roundup_power2_divisions(size_t size); + + // Returns the vector of division factors used for rounding up allocation + // sizes. These divisions apply to size intervals between 1MB and 64GB. + static const std::vector& roundup_power2_divisions() { + return instance().roundup_power2_divisions_; + } + + // Returns the threshold that triggers garbage collection when the ratio of + // used memory to maximum allowed memory exceeds this value. The default is 0, + // meaning no garbage collection is triggered. The value should be in the + // range (0.0, 1.0). + static double garbage_collection_threshold() { + return instance().garbage_collection_threshold_; + } + + // Returns whether the expandable segment feature is enabled. This allows the + // allocator to start with one segment that grows as needed, rather than + // creating a new segment for each allocation. Default is false (expandable + // segments disabled). + static bool use_expandable_segments() { + return instance().use_expandable_segments_; + } + + /* Host allocator settings */ + + // Returns whether the pinned host allocator uses background threads for + // processing events. This is useful for improving performance in scenarios + // where many small allocations are made. Default is false (background threads + // disabled). + static bool pinned_use_background_threads() { + return instance().pinned_use_background_threads_; + } + + /* Settings for both device and host allocator */ + + // Returns the current allocator settings as a string. This string is useful + // to expand device-specific allocator configurations + static std::string last_allocator_settings() { + std::lock_guard lock(instance().last_allocator_settings_mutex_); + return instance().last_allocator_settings_; + } + + // Use `Construct On First Use Idiom` to avoid `Static Initialization Order` + // issue. + static std::unordered_set& getMutableKeys() { + static std::unordered_set keys{ + "max_split_size_mb", + "max_non_split_rounding_mb", + "garbage_collection_threshold", + "roundup_power2_divisions", + "expandable_segments", + "pinned_use_background_threads"}; + return keys; + } + + // Returns the set of valid keys for the allocator configuration. + // This set is used to validate the presence and correctness of keys in + // device-specific configuration parsers. + static const std::unordered_set& getKeys() { + return getMutableKeys(); + } + + // Registers a device-specific configuration parser hook and its key. This + // allows backends to parse additional device-specific configuration options + // from the environment variable. The hook should be a function that takes a + // string (the environment variable value) and parses it to set + // device-specific configuration options. The hook will be called when the + // environment variable is parsed. If a hook is already registered, it will be + // replaced with the new one. + static void registerDeviceConfigParserHook( + std::function&& hook, + const std::unordered_set& keys) { + device_config_parser_hook_ = std::move(hook); + auto& mutable_keys = getMutableKeys(); + for (auto& key : keys) { + TORCH_CHECK_VALUE( + mutable_keys.insert(key).second, + "Duplicated key '", + key, + "' found in device-specific configuration parser hook registration"); + } + } + + // Calls the registered device-specific configuration parser hook with the + // provided environment string. This allows backends to parse additional + // device-specific configuration options from the environment variable. + // If no hook is registered, this function does nothing. + static void callDeviceConfigParserHook(const std::string& env) { + if (device_config_parser_hook_) { + device_config_parser_hook_(env); + } + } + + // Parses the environment variable `env` to update the allocator settings. + // If the environment variable is not set, it does nothing. + // The configuration string should be a comma-separated list of key-value + // pairs, where each key is a configuration option and the value is the + // corresponding setting. For example: + // "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true" + void parseArgs(const std::string& env); + + private: + AcceleratorAllocatorConfig(); + + /* Internal functions for device allocator */ + + // Parse `max_split_size_mb` from environment variable. + size_t parseMaxSplitSize(const ConfigTokenizer& tokenizer, size_t i); + // Parse `max_non_split_rounding_mb` from environment variable. + size_t parseMaxNonSplitRoundingSize( + const ConfigTokenizer& tokenizer, + size_t i); + // Parse `garbage_collection_threshold` from environment variable. + size_t parseGarbageCollectionThreshold( + const ConfigTokenizer& tokenizer, + size_t i); + // Parse `roundup_power2_divisions` from environment variable. + size_t parseRoundUpPower2Divisions( + const ConfigTokenizer& tokenizer, + size_t i); + // Parse `expandable_segments` from environment variable. + size_t parseExpandableSegments(const ConfigTokenizer& tokenizer, size_t i); + + /* Internal functions for host allocator */ + + // Parse `pinned_use_background_threads` from environment variable. + size_t parsePinnedUseBackgroundThreads( + const ConfigTokenizer& tokenizer, + size_t i); + + /* The following members are specifically used for the device allocator. */ + + // The maximum block size that is allowed to be split. + std::atomic max_split_size_{std::numeric_limits::max()}; + // The maximum allowable extra size of a memory block without requiring + // splitting when searching for a free block. + std::atomic max_non_split_rounding_size_{kLargeBuffer}; + // Used to store how memory allocations of different sizes should be rounded + // up to the nearest power of 2 divisions. + std::vector roundup_power2_divisions_; + // The threshold that triggers garbage collection when the ratio of used + // memory to maximum allowed memory exceeds this value. + std::atomic garbage_collection_threshold_{0}; + // A flag to enable expandable segments feature. + std::atomic use_expandable_segments_{false}; + + /* The following members are specifically used for the host allocator. */ + + // A flag to enable background thread for processing events. + std::atomic pinned_use_background_threads_{false}; + + /* The following members are used for both device and host allocator. */ + + // Record the last allocator config environment setting. + std::mutex last_allocator_settings_mutex_; + std::string last_allocator_settings_; + + // Optional hook for parsing additional device-specific allocator settings. + // This allows backends (e.g., CUDA, XPU) to register a custom parser for + // their own environment configuration extensions. + inline static std::function + device_config_parser_hook_{nullptr}; +}; + +C10_API inline void setAllocatorSettings(const std::string& env) { + AcceleratorAllocatorConfig::instance().parseArgs(env); + AcceleratorAllocatorConfig::callDeviceConfigParserHook(env); +} + +C10_API inline std::string getAllocatorSettings() { + return AcceleratorAllocatorConfig::instance().last_allocator_settings(); +} + +struct DeviceConfigParserHookRegistry { + explicit DeviceConfigParserHookRegistry( + std::function&& hook, + const std::unordered_set& keys) { + // Use static method to avoid static initialization order fiasco issues + AcceleratorAllocatorConfig::registerDeviceConfigParserHook( + std::move(hook), keys); + } +}; + +// Assume each config parser has `parseArgs` and `getKeys` methods +#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(parser_cls) \ + namespace { \ + static at::CachingAllocator::DeviceConfigParserHookRegistry \ + g_device_config_parse_hook_registry_instance( \ + [](const std::string& env) { \ + parser_cls::instance().parseArgs(env); \ + }, \ + parser_cls::getKeys()); \ + } + +} // namespace c10::CachingAllocator + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/AutogradState.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/AutogradState.h new file mode 100644 index 0000000000000000000000000000000000000000..9d596b01d233dad00702dcad5269f146672861c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/AutogradState.h @@ -0,0 +1,90 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10 { + +// Structure used to pack all the thread local boolean +// flags used by autograd +struct C10_API AutogradState { + static AutogradState& get_tls_state(); + static void set_tls_state(AutogradState state); + + AutogradState( + bool grad_mode, + bool inference_mode, + bool fw_grad_mode, + bool multithreading_enabled) + : graph_exec_group_(std::nullopt), + grad_mode_(grad_mode), + inference_mode_(inference_mode), + fw_grad_mode_(fw_grad_mode), + multithreading_enabled_(multithreading_enabled), + view_replay_enabled_(false) {} + + void set_grad_mode(bool enabled) { + grad_mode_ = enabled; + } + + void set_fw_grad_mode(bool enabled) { + fw_grad_mode_ = enabled; + } + + void set_inference_mode(bool enabled) { + inference_mode_ = enabled; + } + + void set_multithreading_enabled(bool multithreading_enabled) { + multithreading_enabled_ = multithreading_enabled; + } + + void set_view_replay_enabled(bool view_replay_enabled) { + view_replay_enabled_ = view_replay_enabled; + } + + void set_graph_exec_group(std::optional group) { + graph_exec_group_ = std::move(group); + } + + bool get_grad_mode() const { + return grad_mode_; + } + + bool get_fw_grad_mode() const { + return fw_grad_mode_; + } + + bool get_inference_mode() const { + return inference_mode_; + } + + bool get_multithreading_enabled() const { + return multithreading_enabled_; + } + + bool get_view_replay_enabled() const { + return view_replay_enabled_; + } + + const std::optional& get_graph_exec_group() const { + return graph_exec_group_; + } + + private: + std::optional graph_exec_group_; + bool grad_mode_ : 1; + bool inference_mode_ : 1; + bool fw_grad_mode_ : 1; + bool multithreading_enabled_ : 1; + // NOLINTNEXTLINE(cppcoreguidelines-use-default-member-init) + bool view_replay_enabled_ : 1; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Backend.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Backend.h new file mode 100644 index 0000000000000000000000000000000000000000..d26c0089ae024b876be0df2821e3f562737ff35d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Backend.h @@ -0,0 +1,414 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + +namespace c10 { + +/** + * This legacy enum class defines the set of backends supported by old school, + * code generated Type-based ATen. A "backend" in this sense roughly + * corresponds to the cartesian product of (device type, layout), but restricted + * only to combinations which we actually have kernels for. Backend does NOT + * include dtype. + * + * The reason we are sunsetting this enum class is because it doesn't allow for + * open registration; e.g., if you want to add SparseXLA, you'd have to + * edit this enum; you wouldn't be able to do it out of tree. DispatchKey is + * the replacement for Backend which supports open registration. + * + * NB: The concept of 'Backend' here disagrees with the notion of backend + * exposed to users in torch.backends. Backend here is something like "CPU" + * or "SparseCUDA"; backend in torch.backends is something like "MKL" or + * "CUDNN". + */ +enum class Backend { + CPU, + CUDA, + HIP, + VE, + FPGA, + IPU, + XPU, + SparseCPU, + SparseCUDA, + SparseCsrCPU, + SparseCsrCUDA, + SparseCsrMPS, + SparseMPS, + SparseHIP, + SparseVE, + SparseXPU, + SparsePrivateUse1, + SparseCsrHIP, + SparseCsrVE, + SparseCsrXPU, + SparseCsrPrivateUse1, + MAIA, + XLA, + Vulkan, + Metal, + Meta, + QuantizedCPU, + QuantizedCUDA, + QuantizedXPU, + QuantizedPrivateUse1, + Undefined, + MkldnnCPU, + MPS, + HPU, + Lazy, + MTIA, + PrivateUse1, + NumOptions +}; + +inline Backend dispatchKeyToBackend(DispatchKey t) { + if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) { + return Backend::CPU; + } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) { + return Backend::CUDA; + } else if (t == DispatchKey::HIP) { + return Backend::HIP; + } else if (t == DispatchKey::VE) { + return Backend::VE; + } else if (t == DispatchKey::FPGA) { + return Backend::FPGA; + } else if (t == DispatchKey::MAIA || t == DispatchKey::AutogradMAIA) { + return Backend::MAIA; + } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) { + return Backend::XLA; + } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) { + return Backend::Lazy; + } else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) { + return Backend::MPS; + } else if (t == DispatchKey::Vulkan) { + return Backend::Vulkan; + } else if (t == DispatchKey::Metal) { + return Backend::Metal; + } else if (t == DispatchKey::Meta) { + return Backend::Meta; + } else if (t == DispatchKey::SparseCPU) { + return Backend::SparseCPU; + } else if (t == DispatchKey::SparseCUDA) { + return Backend::SparseCUDA; + } else if (t == DispatchKey::SparseMPS) { + return Backend::SparseMPS; + } else if (t == DispatchKey::SparseCsrMPS) { + return Backend::SparseCsrMPS; + } else if (t == DispatchKey::SparseHIP) { + return Backend::SparseHIP; + } else if (t == DispatchKey::SparseVE) { + return Backend::SparseVE; + } else if (t == DispatchKey::SparsePrivateUse1) { + return Backend::SparsePrivateUse1; + } else if (t == DispatchKey::SparseCsrCPU) { + return Backend::SparseCsrCPU; + } else if (t == DispatchKey::SparseCsrCUDA) { + return Backend::SparseCsrCUDA; + } else if (t == DispatchKey::SparseCsrHIP) { + return Backend::SparseCsrHIP; + } else if (t == DispatchKey::SparseCsrVE) { + return Backend::SparseCsrVE; + } else if (t == DispatchKey::SparseCsrPrivateUse1) { + return Backend::SparseCsrPrivateUse1; + } else if (t == DispatchKey::MkldnnCPU) { + return Backend::MkldnnCPU; + } else if (t == DispatchKey::QuantizedCPU) { + return Backend::QuantizedCPU; + } else if (t == DispatchKey::QuantizedCUDA) { + return Backend::QuantizedCUDA; + } else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) { + return Backend::IPU; + } else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) { + return Backend::XPU; + } else if (t == DispatchKey::SparseXPU) { + return Backend::SparseXPU; + } else if (t == DispatchKey::SparseCsrXPU) { + return Backend::SparseCsrXPU; + } else if (t == DispatchKey::QuantizedXPU) { + return Backend::QuantizedXPU; + } else if (t == DispatchKey::QuantizedPrivateUse1) { + return Backend::QuantizedPrivateUse1; + } else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) { + return Backend::HPU; + } else if (t == DispatchKey::MTIA || t == DispatchKey::AutogradMTIA) { + return Backend::MTIA; + } else if ( + t == DispatchKey::PrivateUse1 || t == DispatchKey::AutogradPrivateUse1) { + return Backend::PrivateUse1; + } else if (t == DispatchKey::Undefined) { + return Backend::Undefined; + } else { + TORCH_CHECK(false, "Unrecognized tensor type ID: ", t); + } +} + +inline DispatchKey backendToDispatchKey(Backend b) { + switch (b) { + case Backend::CPU: + return DispatchKey::CPU; + case Backend::CUDA: + return DispatchKey::CUDA; + case Backend::HIP: + return DispatchKey::HIP; + case Backend::VE: + return DispatchKey::VE; + case Backend::FPGA: + return DispatchKey::FPGA; + case Backend::MAIA: + return DispatchKey::MAIA; + case Backend::XLA: + return DispatchKey::XLA; + case Backend::Lazy: + return DispatchKey::Lazy; + case Backend::IPU: + return DispatchKey::IPU; + case Backend::XPU: + return DispatchKey::XPU; + case Backend::SparseXPU: + return DispatchKey::SparseXPU; + case Backend::SparseCsrXPU: + return DispatchKey::SparseCsrXPU; + case Backend::SparseCPU: + return DispatchKey::SparseCPU; + case Backend::SparseCUDA: + return DispatchKey::SparseCUDA; + case Backend::SparseMPS: + return DispatchKey::SparseMPS; + case Backend::SparseCsrMPS: + return DispatchKey::SparseCsrMPS; + case Backend::SparseHIP: + return DispatchKey::SparseHIP; + case Backend::SparseVE: + return DispatchKey::SparseVE; + case Backend::SparsePrivateUse1: + return DispatchKey::SparsePrivateUse1; + case Backend::SparseCsrCPU: + return DispatchKey::SparseCsrCPU; + case Backend::SparseCsrCUDA: + return DispatchKey::SparseCsrCUDA; + case Backend::SparseCsrHIP: + return DispatchKey::SparseCsrHIP; + case Backend::SparseCsrVE: + return DispatchKey::SparseCsrVE; + case Backend::SparseCsrPrivateUse1: + return DispatchKey::SparseCsrPrivateUse1; + case Backend::MkldnnCPU: + return DispatchKey::MkldnnCPU; + case Backend::Vulkan: + return DispatchKey::Vulkan; + case Backend::Metal: + return DispatchKey::Metal; + case Backend::Meta: + return DispatchKey::Meta; + case Backend::QuantizedCPU: + return DispatchKey::QuantizedCPU; + case Backend::QuantizedCUDA: + return DispatchKey::QuantizedCUDA; + case Backend::QuantizedPrivateUse1: + return DispatchKey::QuantizedPrivateUse1; + case Backend::Undefined: + return DispatchKey::Undefined; + case Backend::MPS: + return DispatchKey::MPS; + case Backend::HPU: + return DispatchKey::HPU; + case Backend::MTIA: + return DispatchKey::MTIA; + case Backend::PrivateUse1: + return DispatchKey::PrivateUse1; + default: + TORCH_CHECK(false, "Unknown backend"); + } +} + +inline DeviceType backendToDeviceType(Backend b) { + switch (b) { + case Backend::CPU: + case Backend::MkldnnCPU: + case Backend::SparseCPU: + case Backend::SparseCsrCPU: + case Backend::QuantizedCPU: + return DeviceType::CPU; + case Backend::CUDA: + case Backend::SparseCUDA: + case Backend::QuantizedCUDA: + case Backend::SparseCsrCUDA: + return DeviceType::CUDA; + case Backend::HIP: + return DeviceType::HIP; + case Backend::VE: + return DeviceType::VE; + case Backend::FPGA: + return DeviceType::FPGA; + case Backend::MAIA: + return DeviceType::MAIA; + case Backend::XLA: + return DeviceType::XLA; + case Backend::Lazy: + return DeviceType::Lazy; + case Backend::SparseHIP: + return DeviceType::HIP; + case Backend::SparseVE: + return DeviceType::VE; + case Backend::SparseCsrHIP: + return DeviceType::HIP; + case Backend::SparseCsrVE: + return DeviceType::VE; + case Backend::IPU: + return DeviceType::IPU; + case Backend::XPU: + case Backend::SparseXPU: + case Backend::SparseCsrXPU: + case Backend::QuantizedXPU: + return DeviceType::XPU; + case Backend::Vulkan: + return DeviceType::Vulkan; + case Backend::Metal: + return DeviceType::Metal; + case Backend::Meta: + return DeviceType::Meta; + case Backend::MPS: + case Backend::SparseMPS: + case Backend::SparseCsrMPS: + return DeviceType::MPS; + case Backend::HPU: + return DeviceType::HPU; + case Backend::MTIA: + return DeviceType::MTIA; + case Backend::PrivateUse1: + case Backend::SparsePrivateUse1: + case Backend::SparseCsrPrivateUse1: + case Backend::QuantizedPrivateUse1: + return DeviceType::PrivateUse1; + case Backend::Undefined: + TORCH_CHECK(false, "Undefined backend is not a valid device type"); + default: + TORCH_CHECK(false, "Unknown backend"); + } +} + +inline const char* toString(Backend b) { + switch (b) { + case Backend::CPU: + return "CPU"; + case Backend::CUDA: + return "CUDA"; + case Backend::HIP: + return "HIP"; + case Backend::VE: + return "VE"; + case Backend::FPGA: + return "FPGA"; + case Backend::XPU: + return "XPU"; + case Backend::IPU: + return "IPU"; + case Backend::MAIA: + return "MAIA"; + case Backend::XLA: + return "XLA"; + case Backend::Lazy: + return "Lazy"; + case Backend::MPS: + return "MPS"; + case Backend::SparseCPU: + return "SparseCPU"; + case Backend::SparseCUDA: + return "SparseCUDA"; + case Backend::SparseMPS: + return "SparseMPS"; + case Backend::SparseCsrMPS: + return "SparseCsrMPS"; + case Backend::SparseHIP: + return "SparseHIP"; + case Backend::SparseVE: + return "SparseVE"; + case Backend::SparseXPU: + return "SparseXPU"; + case Backend::SparsePrivateUse1: + return "SparsePrivateUse1"; + case Backend::SparseCsrCPU: + return "SparseCsrCPU"; + case Backend::SparseCsrCUDA: + return "SparseCsrCUDA"; + case Backend::SparseCsrHIP: + return "SparseCsrHIP"; + case Backend::SparseCsrVE: + return "SparseCsrVE"; + case Backend::SparseCsrXPU: + return "SparseCsrXPU"; + case Backend::SparseCsrPrivateUse1: + return "SparseCsrPrivateUse1"; + case Backend::MkldnnCPU: + return "MkldnnCPU"; + case Backend::Vulkan: + return "Vulkan"; + case Backend::Metal: + return "Metal"; + case Backend::Meta: + return "Meta"; + case Backend::QuantizedCPU: + return "QuantizedCPU"; + case Backend::QuantizedCUDA: + return "QuantizedCUDA"; + case Backend::QuantizedXPU: + return "QuantizedXPU"; + case Backend::QuantizedPrivateUse1: + return "QuantizedPrivateUse1"; + case Backend::HPU: + return "HPU"; + case Backend::MTIA: + return "MTIA"; + case Backend::PrivateUse1: + return "PrivateUseOne"; + default: + return "UNKNOWN_BACKEND"; + } +} + +inline bool isSparse(Backend b) { + switch (b) { + case Backend::SparseXPU: + case Backend::SparseCPU: + case Backend::SparseCUDA: + case Backend::SparseMPS: + case Backend::SparseHIP: + case Backend::SparseVE: + case Backend::SparsePrivateUse1: + return true; + default: + return false; + } +} + +inline bool isSparseCsr(Backend b) { + switch (b) { + case Backend::SparseCsrXPU: + case Backend::SparseCsrCPU: + case Backend::SparseCsrCUDA: + case Backend::SparseCsrHIP: + case Backend::SparseCsrVE: + case Backend::SparseCsrPrivateUse1: + return true; + default: + return false; + } +} + +} // namespace c10 + +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CPUAllocator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CPUAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..d43d48e32ee794092b23a488cbb8518a6d5d2623 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CPUAllocator.h @@ -0,0 +1,64 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +// TODO: rename to c10 +C10_DECLARE_bool(caffe2_report_cpu_memory_usage); + +namespace c10 { + +using MemoryDeleter = void (*)(void*); + +// A helper function that is basically doing nothing. +C10_API void NoDelete(void* /*unused*/); + +// A simple struct that is used to report C10's memory allocation, +// deallocation status and out-of-memory events to the profiler +class C10_API ProfiledCPUMemoryReporter { + public: + ProfiledCPUMemoryReporter() = default; + void New(void* ptr, size_t nbytes); + void OutOfMemory(size_t nbytes); + void Delete(void* ptr); + + private: + std::mutex mutex_; + std::unordered_map size_table_; + size_t allocated_ = 0; + size_t log_cnt_ = 0; +}; + +C10_API ProfiledCPUMemoryReporter& profiledCPUMemoryReporter(); + +// Get the CPU Allocator. +C10_API at::Allocator* GetCPUAllocator(); +// Sets the CPU allocator to the given allocator: the caller gives away the +// ownership of the pointer. +C10_API void SetCPUAllocator(at::Allocator* alloc, uint8_t priority = 0); + +// Get the Default CPU Allocator +C10_API at::Allocator* GetDefaultCPUAllocator(); + +// Get the Default Mobile CPU Allocator +C10_API at::Allocator* GetDefaultMobileCPUAllocator(); + +// The CPUCachingAllocator is experimental and might disappear in the future. +// The only place that uses it is in StaticRuntime. +// Set the CPU Caching Allocator +C10_API void SetCPUCachingAllocator(Allocator* alloc, uint8_t priority = 0); +// Get the CPU Caching Allocator +C10_API Allocator* GetCPUCachingAllocator(); + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CachingDeviceAllocator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CachingDeviceAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..23b413de834aae788e8f763f60cd75ec7750dbea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CachingDeviceAllocator.h @@ -0,0 +1,126 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::CachingDeviceAllocator { + +using namespace c10::CachingAllocator; + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from device memory allocation. + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via device memory deallocation) + StatArray inactive_split; + + // SUM: bytes allocated by this memory allocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to device malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to device memory allocation + // after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of device memory allocation calls. This includes both + // mapped and malloced memory. + int64_t num_device_alloc = 0; + + // COUNT: total number of device memory deallocation calls. This includes both + // un-mapped and free memory. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + +} // namespace c10::CachingDeviceAllocator + +namespace c10 { + +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by Graph mode capture_begin. +// second is set if the instance is created by Graph mode graph_pool_handle. +using MempoolId_t = std::pair; + +struct C10_API DeviceAllocator : public c10::Allocator { + DeviceAllocator(); + ~DeviceAllocator() override; + + // Returns true if the allocator has been properly initialized and is ready + // for use + virtual bool initialized() = 0; + + // Releases all cached device memory from the specified memory pool back to + // the system + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; + + // Associates a memory allocation with a stream to establish dependency + // tracking. Prevents memory reuse until all operations on the specified + // stream complete + virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0; + + // Retrieves comprehensive memory statistics for the specified device, + // including allocation patterns, usage metrics + virtual CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + + // Resets cumulative allocation statistics for the specified device to zero + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + + // Resets peak memory usage statistics for the specified device + virtual void resetPeakStats(c10::DeviceIndex device) = 0; + + // Return the free memory size and total memory size in bytes for the + // specified device. + virtual std::pair getMemoryInfo(c10::DeviceIndex device) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "getMemoryInfo is not implemented for this allocator yet."); + } +}; + +// This function is used to get the DeviceAllocator for a specific device type +// and keep backward compatibility with c10::GetAllocator. +C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { + TORCH_CHECK( + t != DeviceType::CPU, + "getDeviceAllocator is not supported for CPU device type."); + auto* allocator = c10::GetAllocator(t); + auto* device_allocator = dynamic_cast(allocator); + TORCH_INTERNAL_ASSERT( + device_allocator, "Allocator for ", t, " is not a DeviceAllocator."); + return device_allocator; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CompileTimeFunctionPointer.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CompileTimeFunctionPointer.h new file mode 100644 index 0000000000000000000000000000000000000000..28dd52759e8de0f4f2f2947e96ccd0dd7467a95c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CompileTimeFunctionPointer.h @@ -0,0 +1,62 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +/** + * Represent a function pointer as a C++ type. + * This allows using the function pointer as a type + * in a template and calling it from inside the template + * allows the compiler to inline the call because it + * knows the function pointer at compile time. + * + * Example 1: + * int add(int a, int b) {return a + b;} + * using Add = TORCH_FN_TYPE(add); + * template struct Executor { + * int execute(int a, int b) { + * return Func::func_ptr()(a, b); + * } + * }; + * Executor executor; + * EXPECT_EQ(3, executor.execute(1, 2)); + * + * Example 2: + * int add(int a, int b) {return a + b;} + * template int execute(Func, int a, int b) { + * return Func::func_ptr()(a, b); + * } + * EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2)); + */ +template +struct CompileTimeFunctionPointer final { + static_assert( + guts::is_function_type::value, + "TORCH_FN can only wrap function types."); + using FuncType = FuncType_; + + static constexpr FuncType* func_ptr() { + return func_ptr_; + } +}; + +template +struct is_compile_time_function_pointer : std::false_type {}; +template +struct is_compile_time_function_pointer< + CompileTimeFunctionPointer> : std::true_type {}; + +} // namespace c10 + +#define TORCH_FN_TYPE(func) \ + ::c10::CompileTimeFunctionPointer< \ + std::remove_pointer_t>, \ + func> +#define TORCH_FN(func) TORCH_FN_TYPE(func)() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ConstantSymNodeImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ConstantSymNodeImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..22a3cf2104d1c55c0d18681906cc4ae9c2c85800 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ConstantSymNodeImpl.h @@ -0,0 +1,115 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// Unlike other SymNodeImpl, this cannot be "dispatched" conventionally, +// as it typically needs to defer to another SymNodeImpl +// +// Can either represent a bool, int (don't support float yet) this is useful +// for representing otherwise unrepresentable large negative integer constant. +template +class C10_API ConstantSymNodeImpl : public SymNodeImpl { + static_assert( + ::std::is_same_v || ::std::is_same_v, + "ConstantSymNodeImpl can only accept int64_t or bool types"); + + public: + ConstantSymNodeImpl(T val) : value_(val) {} + + bool is_int() override { + return is_int_(); + } + bool is_bool() override { + return is_bool_(); + } + bool is_float() override { + return false; + } + int64_t guard_int( + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) override { + TORCH_CHECK(is_int(), "not an int"); + return int_(); + } + bool guard_bool( + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) override { + TORCH_CHECK(is_bool(), "not a bool"); + return bool_(); + } + double guard_float( + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) override { + TORCH_CHECK(false, "not a float"); + } + int64_t int_() override { + TORCH_CHECK(is_int(), "not an int"); + return ::std::get(value_); + } + bool bool_() override { + TORCH_CHECK(is_bool(), "not a bool"); + return ::std::get(value_); + } + bool has_hint() override { + return true; + } + c10::SymNode eq(const c10::SymNode& other) override; + c10::SymNode ne(const c10::SymNode& other) override; + c10::SymNode ge(const c10::SymNode& other) override; + c10::SymNode le(const c10::SymNode& other) override; + c10::SymNode lt(const c10::SymNode& other) override; + c10::SymNode gt(const c10::SymNode& other) override; + c10::SymNode mul(const c10::SymNode& other) override; + ::std::string str() override { + if constexpr (is_int_()) { + return ::std::to_string(::std::get(value_)); + } else { + return ::std::get(value_) ? "true" : "false"; + } + } + std::optional constant_int() override { + if constexpr (is_int_()) { + return ::std::get(value_); + } else { + return std::nullopt; + } + } + std::optional constant_bool() override { + if constexpr (is_bool_()) { + return ::std::get(value_); + } else { + return std::nullopt; + } + } + bool is_constant() override { + return true; + } + bool is_symbolic() override { + return false; + } + + private: + ::std::variant value_; + + static constexpr bool is_int_() { + return ::std::is_same_v; + } + static constexpr bool is_bool_() { + return ::std::is_same_v; + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Contiguity.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Contiguity.h new file mode 100644 index 0000000000000000000000000000000000000000..014903df018c3db2b2df40ca72ee4cd40ebf21c6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Contiguity.h @@ -0,0 +1,314 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +template +bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + if (numel == 0) { + return true; + } + + T expected_stride = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (size_d == 1) { + continue; + } + + if (strides[d] != expected_stride) { + return false; + } + expected_stride *= size_d; + } + return true; +} + +// Return a SymBool with underlying symbolic expression that represents +// contiguity. Guaranteed not to throw DDE, may returns a symbolic expressions +// or symbolic True. +inline static c10::SymBool _compute_contiguous_sym( + ArrayRef sizes, + ArrayRef strides, + const c10::SymInt& numel) { + // If this return true, the tensor is contiguous indeed. Otherwise it could be + // either. + auto is_contiguous_or_false = [&]() { + if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { + return true; + } + + // When calculating the expected stride, we can choose to multiply + // with max(1, size[d]) or size[d]. Regardless, this is ok for this + // function. Why? + // (1) If size[d] == 0, then the tensor is contiguous and if + // we return true or false it won't break this function. + // (2) If size[d] is not 0, then max(1,size[d]) and size[d] are equal. + // Therefore, if we choose to use max(1, size[d]) or size[d] to + // calculate the expected stride, the result is the same. + // + // We symbolically check both paths to maximize the cases where this + // function returns true. This is because make_contiguous_strides_for adds + // the max symbolically, and in some other situations the max might not be + // there. And we want to ensure we return true in both cases. + c10::SymInt expected_stride = 1; + c10::SymInt expected_stride_max = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) { + continue; + } + + if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride)) && + TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride_max))) { + return false; + } + expected_stride_max *= sizes[d].max(1); + expected_stride *= sizes[d]; + } + return true; + }; + + // We try to minimize creating large symbolic expressions when not needed to + // avoid symbolic evaluation perf issues. + if (is_contiguous_or_false()) { + return c10::SymBool(true); + } + + // Build a single expression that represents contiguity and return it. + c10::SymBool is_empty = sym_eq(numel, 0); + c10::SymBool is_contiguous_cond = true; + + c10::SymInt expected_stride = 1; + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + is_contiguous_cond = is_contiguous_cond.sym_and( + size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride))); + expected_stride = expected_stride * size_d; + } + return is_contiguous_cond.sym_or(is_empty); +} + +// When T is SymInt this function may throw a data dependent error. +// _compute_channels_last_contiguous_2d_sym does not. Only use this function +// when inputs are hinted. +template +bool _compute_channels_last_contiguous_2d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 4: { + T expected = 1; + for (auto& d : {1, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (size_d != 1) { + if (strides[d] != expected) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +// Return a SymBool with underlying symbolic expression that represents +// contiguity. Guaranteed not to throw DDE, may returns a symbolic expressions +// or symbolic True. +inline static c10::SymBool _compute_channels_last_contiguous_2d_sym( + ArrayRef sizes, + ArrayRef strides) { + switch (sizes.size()) { + case 4: { + // When this function return True, result always true. When it return + // False, result could be False or data dependent. + auto guard_or_false = [&]() { + c10::SymInt expected = 1; + for (auto& d : {1, 3, 2, 0}) { + const auto& size_d = sizes[d]; + // Not taking this branch could make this return False instead of True + // but not vice-versa. so its ok. + if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) { + continue; + } + // Taking this branch could make this return False instead of True + // but not vice-versa. so its ok. + if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected))) { + return false; + } + expected *= size_d; + } + return true; + }; + + // We try to minimize creating large symbolic expressions when not needed + // to avoid symbolic evaluation perf issues. + if (guard_or_false()) { + return c10::SymBool(true); + } + + // Result is either false, or data dependent. + c10::SymInt expected_stride = 1; + c10::SymBool cond = true; + + for (auto& d : {1, 3, 2, 0}) { + const auto& size_d = sizes[d]; + cond = cond.sym_and( + size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride))); + expected_stride *= size_d; + } + return cond; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return c10::SymBool(false); + default: + return c10::SymBool(false); + } +} + +// When T is SymInt this function may throw a data dependent error. +// _compute_channels_last_contiguous_3d_sym does not. Only use this function +// when inputs are hinted. +template +bool _compute_channels_last_contiguous_3d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 5: { + T expected = 1; + for (auto& d : {1, 4, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (size_d != 1) { + if (strides[d] != expected) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +inline static c10::SymBool _compute_channels_last_contiguous_3d_sym( + ArrayRef sizes, + ArrayRef strides) { + switch (sizes.size()) { + case 5: { + // When this function return True, result always true. When it return + // False, result could be False or data dependent. + auto guard_or_false = [&]() { + c10::SymInt expected = 1; + for (auto& d : {1, 4, 3, 2, 0}) { + const auto& size_d = sizes[d]; + // Not taking this branch could make this return False instead of True + // but not vice-versa. so its ok. + if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) { + continue; + } + // Taking this branch could make this return False instead of True + // but not vice-versa. so its ok. + if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected))) { + return false; + } + expected *= size_d; + } + return true; + }; + + // We try to minimize creating large symbolic expressions when not needed + // to avoid symbolic evaluation perf issues. + if (guard_or_false()) { + return c10::SymBool(true); + } + + // Result is either false, or data dependent. + c10::SymInt expected_stride = 1; + c10::SymBool cond = true; + + for (auto& d : {1, 4, 3, 2, 0}) { + const auto& size_d = sizes[d]; + cond = cond.sym_and( + size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride))); + expected_stride *= size_d; + } + return cond; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return c10::SymBool(false); + default: + return c10::SymBool(false); + } +} + +template +bool _compute_non_overlapping_and_dense( + ArrayRef sizes, + ArrayRef strides) { + auto dim = sizes.size(); + if (dim == 1) { + return sizes[0] < 2 || strides[0] == 1; + } + SmallVector perm; + perm.resize(dim); + for (const auto i : c10::irange(dim)) { + perm[i] = i; + } + // Sort by strides, leaving 0 and 1 sized dims at the end of the array + std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { + if (sizes[a] < 2) { + return false; + } else if (sizes[b] < 2) { + return true; + } + return strides[a] < strides[b]; + }); + T require_stride = 1; + for (const auto i : c10::irange(dim)) { + const auto& size_perm_i = sizes[perm[i]]; + if (size_perm_i < 2) { + return true; + } + if (strides[perm[i]] != require_stride) { + return false; + } + require_stride *= size_perm_i; + } + return true; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CopyBytes.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CopyBytes.h new file mode 100644 index 0000000000000000000000000000000000000000..bc2632794299da5a6c9c5d30be0b4591600bab2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/CopyBytes.h @@ -0,0 +1,53 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +using CopyBytesFunction = void (*)( + size_t nbytes, + const void* src, + Device src_device, + void* dst, + Device dst_device); + +struct C10_API _CopyBytesFunctionRegisterer { + _CopyBytesFunctionRegisterer( + DeviceType from, + DeviceType to, + CopyBytesFunction func_sync, + CopyBytesFunction func_async = nullptr); +}; + +#define REGISTER_COPY_BYTES_FUNCTION(from, to, ...) \ + namespace { \ + static _CopyBytesFunctionRegisterer C10_ANONYMOUS_VARIABLE( \ + g_copy_function)(from, to, __VA_ARGS__); \ + } + +/* + * WARNING: Implementations for this function are currently registered from + * ATen and caffe2, not yet from c10. Don't use this if not either ATen + * or caffe2 is present as well. + * We can't move them yet, because the CUDA implementations aren't unified yet + * between ATen and caffe2. + * We're planning to move the implementations into c10/backend/xxx + * to make c10 self contained again. + */ +C10_API void CopyBytes( + size_t nbytes, + const void* src, + Device src_device, + void* dst, + Device dst_device, + bool async); +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DefaultDtype.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DefaultDtype.h new file mode 100644 index 0000000000000000000000000000000000000000..240c173ca22ae28ab20e243890b2f8a054156fa5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DefaultDtype.h @@ -0,0 +1,20 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace caffe2 { +class TypeMeta; +} // namespace caffe2 + +namespace c10 { +C10_API void set_default_dtype(caffe2::TypeMeta dtype); +C10_API const caffe2::TypeMeta get_default_dtype(); +C10_API ScalarType get_default_dtype_as_scalartype(); +C10_API const caffe2::TypeMeta get_default_complex_dtype(); +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DefaultTensorOptions.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DefaultTensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..8d5e66ec405ddeb1494d987a034cf1b945663667 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DefaultTensorOptions.h @@ -0,0 +1,50 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +struct TensorOptions; + +/// Like TensorOptions, but all fields are guaranteed to be filled. +struct DefaultTensorOptions { + DefaultTensorOptions() = default; + + caffe2::TypeMeta dtype() const noexcept { + return dtype_; + } + Device device() const noexcept { + return device_; + } + Layout layout() const noexcept { + return layout_; + } + bool requires_grad() const noexcept { + return requires_grad_; + } + + // Defined in TensorOptions.h + inline DefaultTensorOptions& merge(const TensorOptions& options); + + private: + caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 64-bit + Device device_ = at::kCPU; // 32-bit + Layout layout_ = at::kStrided; // 8-bit + bool requires_grad_ = false; // 8-bit +}; + +inline const DefaultTensorOptions& getDefaultTensorOptions() { + static const auto options = DefaultTensorOptions(); + return options; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Device.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Device.h new file mode 100644 index 0000000000000000000000000000000000000000..d3380f434c6c8284476ac3bc662fd88e10289a86 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Device.h @@ -0,0 +1,221 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10 { + +/// An index representing a specific device; e.g., the 1 in GPU 1. +/// A DeviceIndex is not independently meaningful without knowing +/// the DeviceType it is associated; try to use Device rather than +/// DeviceIndex directly. +using DeviceIndex = int8_t; + +/// Represents a compute device on which a tensor is located. A device is +/// uniquely identified by a type, which specifies the type of machine it is +/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the +/// specific compute device when there is more than one of a certain type. The +/// device index is optional, and in its defaulted state represents (abstractly) +/// "the current device". Further, there are two constraints on the value of the +/// device index, if one is explicitly stored: +/// 1. A negative index represents the current device, a non-negative index +/// represents a specific, concrete device, +/// 2. When the device type is CPU, the device index must be zero. +struct C10_API Device final { + using Type = DeviceType; + + /// Constructs a new `Device` from a `DeviceType` and an optional device + /// index. + /* implicit */ Device(DeviceType type, DeviceIndex index = -1) + : type_(type), index_(index) { + validate(); + } + + /// Constructs a `Device` from a string description, for convenience. + /// The string supplied must follow the following schema: + /// `(cpu|cuda)[:]` + /// where `cpu` or `cuda` specifies the device type, and + /// `:` optionally specifies a device index. + /* implicit */ Device(const std::string& device_string); + + /// Returns true if the type and index of this `Device` matches that of + /// `other`. + bool operator==(const Device& other) const noexcept { + return this->type_ == other.type_ && this->index_ == other.index_; + } + + /// Returns true if the type or index of this `Device` differs from that of + /// `other`. + bool operator!=(const Device& other) const noexcept { + return !(*this == other); + } + + /// Sets the device index. + void set_index(DeviceIndex index) { + index_ = index; + } + + /// Returns the type of device this is. + DeviceType type() const noexcept { + return type_; + } + + /// Returns the optional index. + DeviceIndex index() const noexcept { + return index_; + } + + /// Returns true if the device has a non-default index. + bool has_index() const noexcept { + return index_ != -1; + } + + /// Return true if the device is of CUDA type. + bool is_cuda() const noexcept { + return type_ == DeviceType::CUDA; + } + + /// Return true if the device is of PrivateUse1 type. + bool is_privateuseone() const noexcept { + return type_ == DeviceType::PrivateUse1; + } + + /// Return true if the device is of MPS type. + bool is_mps() const noexcept { + return type_ == DeviceType::MPS; + } + + /// Return true if the device is of HIP type. + bool is_hip() const noexcept { + return type_ == DeviceType::HIP; + } + + /// Return true if the device is of VE type. + bool is_ve() const noexcept { + return type_ == DeviceType::VE; + } + + /// Return true if the device is of XPU type. + bool is_xpu() const noexcept { + return type_ == DeviceType::XPU; + } + + /// Return true if the device is of IPU type. + bool is_ipu() const noexcept { + return type_ == DeviceType::IPU; + } + + /// Return true if the device is of XLA type. + bool is_xla() const noexcept { + return type_ == DeviceType::XLA; + } + + /// Return true if the device is of MTIA type. + bool is_mtia() const noexcept { + return type_ == DeviceType::MTIA; + } + + /// Return true if the device is of HPU type. + bool is_hpu() const noexcept { + return type_ == DeviceType::HPU; + } + + /// Return true if the device is of Lazy type. + bool is_lazy() const noexcept { + return type_ == DeviceType::Lazy; + } + + /// Return true if the device is of Vulkan type. + bool is_vulkan() const noexcept { + return type_ == DeviceType::Vulkan; + } + + /// Return true if the device is of Metal type. + bool is_metal() const noexcept { + return type_ == DeviceType::Metal; + } + + /// Return true if the device is of MAIA type. + bool is_maia() const noexcept { + return type_ == DeviceType::MAIA; + } + + /// Return true if the device is of META type. + bool is_meta() const noexcept { + return type_ == DeviceType::Meta; + } + + /// Return true if the device is of CPU type. + bool is_cpu() const noexcept { + return type_ == DeviceType::CPU; + } + + /// Return true if the device supports arbitrary strides. + bool supports_as_strided() const noexcept { + return type_ != DeviceType::IPU && type_ != DeviceType::XLA && + type_ != DeviceType::Lazy; + } + + /// Same string as returned from operator<<. + std::string str() const; + + private: + DeviceType type_; + DeviceIndex index_ = -1; + void validate() { + // Removing these checks in release builds noticeably improves + // performance in micro-benchmarks. + // This is safe to do, because backends that use the DeviceIndex + // have a later check when we actually try to switch to that device. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index_ >= -1, + "Device index must be -1 or non-negative, got ", + static_cast(index_)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got ", + static_cast(index_)); + } +}; + +C10_API std::ostream& operator<<(std::ostream& stream, const Device& device); + +} // namespace c10 + +namespace std { +template <> +struct hash { + size_t operator()(c10::Device d) const noexcept { + // Are you here because this static assert failed? Make sure you ensure + // that the bitmasking code below is updated accordingly! + static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit"); + static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); + // Note [Hazard when concatenating signed integers] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // We must first convert to a same-sized unsigned type, before promoting to + // the result type, to prevent sign extension when any of the values is -1. + // If sign extension occurs, you'll clobber all of the values in the MSB + // half of the resulting integer. + // + // Technically, by C/C++ integer promotion rules, we only need one of the + // uint32_t casts to the result type, but we put in both for explicitness's + // sake. + uint32_t bits = static_cast(static_cast(d.type())) + << 16 | + static_cast(static_cast(d.index())); + return std::hash{}(bits); + } +}; +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceArray.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceArray.h new file mode 100644 index 0000000000000000000000000000000000000000..b2b179b4d2d82385aefe1f1b79cb2069120500d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceArray.h @@ -0,0 +1,33 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include +#include +#include + +namespace c10 { + +template +class DeviceArray { + public: + DeviceArray(c10::Allocator& allocator, size_t size) + : data_ptr_(allocator.allocate(size * sizeof(T))) { + static_assert(std::is_trivial_v, "T must be a trivial type"); + TORCH_INTERNAL_ASSERT( + 0 == (reinterpret_cast(data_ptr_.get()) % alignof(T)), + "c10::DeviceArray: Allocated memory is not aligned for this data type"); + } + + T* get() { + return static_cast(data_ptr_.get()); + } + + private: + c10::DataPtr data_ptr_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceCapability.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceCapability.h new file mode 100644 index 0000000000000000000000000000000000000000..85477281261bed35e2652ddc471c9bae4042707a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceCapability.h @@ -0,0 +1,81 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10 { + +constexpr size_t NUMBER_OF_DEVICE_CAPABILITIES = NumScalarTypes; + +// Generate bitfields for each scalar type +#define DEFINE_SCALAR_TYPE(_1, n) unsigned int has_##n : 1; + +// Generate enum indices for each scalar type +#define DEFINE_SCALAR_ENUM(_1, name) kIndex_##name, + +enum ScalarTypeIndex { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_ENUM) +}; + +/** + * @brief DeviceCapability represents the the common capabilities that all + * devices should support. + * + * This struct provides a compact way to represent the common capabilities that + * all devices should support. Includes the following capabilities: + * - Supported data types + * + * Purpose + * - Enable device-specific optimizations based on supported capabilities + * + * Contract + * + * Supported data types: + * - Each bitfield represents support for one device capability + * - Bit value 1 means the capability is supported, 0 means not supported + * - The struct is initialized with all capabilities enabled by default + * + * @note Adding New Capabilities + * + * 1. Define the new capability in the `DeviceCapability` struct + * 2. Update the support of the new capability in each accelerator + * implementation + * 3. Add the new capability to the returned PyObject Dictionary + */ +struct C10_API DeviceCapability { + union { + struct { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE) + } supported_scalar_types; + uint64_t capability_bits; // Allow direct bit manipulation + } capability_data; + + // Default constructor with all capabilities enabled. + DeviceCapability() { + capability_data.capability_bits = + ((1ULL << NUMBER_OF_DEVICE_CAPABILITIES) - 1); + } + + // Iterate supported ScalarTypes without allocating a vector + template + void forEachSupportedScalarType(F&& visitor) const { +#define VISIT_SCALAR_TYPE(_1, n) \ + if (capability_data.supported_scalar_types.has_##n) { \ + visitor(ScalarType::n); \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(VISIT_SCALAR_TYPE) + +#undef VISIT_SCALAR_TYPE + } +}; + +#undef DEFINE_SCALAR_ENUM +#undef DEFINE_SCALAR_TYPE +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceGuard.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..389ac29d10029d915279857f4fb4e2ffeb880307 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceGuard.h @@ -0,0 +1,207 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +/// RAII guard that sets a certain default device in its constructor, and +/// changes it back to the device that was originally active upon destruction. +/// +/// The device is always reset to the one that was active at the time of +/// construction of the guard. Even if you `set_device` after construction, the +/// destructor will still reset the device to the one that was active at +/// construction time. +/// +/// This device guard does NOT have an uninitialized state; it is guaranteed +/// to reset a device on exit. If you are in a situation where you *might* +/// want to setup a guard (i.e., are looking for the moral equivalent +/// of std::optional), see OptionalDeviceGuard. +class DeviceGuard { + public: + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit DeviceGuard() = delete; + + /// Set the current device to the passed Device. + explicit DeviceGuard(Device device) : guard_(device) {} + + /// This constructor is for testing only. + explicit DeviceGuard( + Device device, + const impl::DeviceGuardImplInterface* impl) + : guard_(device, impl) {} + + ~DeviceGuard() = default; + + /// Copy is disallowed + DeviceGuard(const DeviceGuard&) = delete; + DeviceGuard& operator=(const DeviceGuard&) = delete; + + /// Move is disallowed, as DeviceGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + DeviceGuard(DeviceGuard&& other) = delete; + DeviceGuard& operator=(DeviceGuard&& other) = delete; + + /// Sets the device to the given one. The specified device must be consistent + /// with the device type originally specified during guard construction. + /// + /// TODO: The consistency check here is inconsistent with StreamGuard's + /// behavior with set_stream, where a stream on a different device than + /// the original one isn't an error; we just reset the stream and then + /// switch devices. + void reset_device(at::Device device) { + guard_.reset_device(device); + } + + /// This method is for testing only. + void reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl) { + guard_.reset_device(device, impl); + } + + /// Sets the device index to the given one. The device type is inferred + /// from the original device type the guard was constructed with. + void set_index(DeviceIndex index) { + guard_.set_index(index); + } + + /// Returns the device that was set at the time the guard was constructed. + Device original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device. + Device current_device() const { + return guard_.current_device(); + } + + private: + impl::InlineDeviceGuard guard_; +}; + +/** + * A OptionalDeviceGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * Morally, a OptionalDeviceGuard is equivalent to std::optional, + * but with extra constructors and methods as appropriate. + * + * Besides its obvious use (optionally applying a DeviceGuard), + * OptionalDeviceGuard is often also used for the following idiom: + * + * OptionalDeviceGuard g; + * for (const auto& t : tensors) { + * g.set_device(t.device()); + * do_something_with(t); + * } + * + * This usage is marginally more efficient than constructing a DeviceGuard every + * iteration of the for loop, as it avoids an unnecessary device reset. + * + * Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized. This occurs + * when you use the nullary constructor, or pass a nullopt to the constructor. + * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the + * original device was and they do not reset on destruction. This is why + * original_device() and current_device() return std::optional rather + * than Device (as they do in DeviceGuard), and also is why we didn't just + * provide OptionalDeviceGuard by default and hide DeviceGuard from users. + * + * The semantics of an OptionalDeviceGuard are exactly explained by thinking + * of it as an std::optional. In particular, an initialized + * OptionalDeviceGuard doesn't restore device to its value at construction; it + * restores device to its value *at initialization*. So if you have the + * program: + * + * setDevice(1); + * OptionalDeviceGuard g; + * setDevice(2); + * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes! + * + * On destruction, g will reset device to 2, rather than 1. + * + * An uninitialized OptionalDeviceGuard is distinct from a (initialized) + * DeviceGuard whose original_device_ and current_device_ match, since the + * DeviceGuard will still reset the device to original_device_. + */ +class OptionalDeviceGuard { + public: + /// Create an uninitialized guard. Set the guard later using reset_device. + explicit OptionalDeviceGuard() = default; + + /// Initialize the guard, setting the current device to the passed Device. + explicit OptionalDeviceGuard(Device device) : guard_(device) {} + + /// Initialize the guard if a Device is passed; otherwise leave the + /// guard uninitialized. + explicit OptionalDeviceGuard(std::optional device) : guard_(device) {} + + /// Constructor for testing only. + explicit OptionalDeviceGuard( + Device device, + const impl::DeviceGuardImplInterface* impl) + : guard_(device, impl) {} + + ~OptionalDeviceGuard() = default; + /// Copy is disallowed + OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; + OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; + + /// Move is disallowed + /// See Note [Explicit initialization of optional fields] + /// and // Note [Move construction for RAII guards is tricky] + /// for rationale. + OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete; + OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete; + + /// Sets the device to the given one. The specified device must be consistent + /// with the device type originally specified during guard construction. + void reset_device(at::Device device) { + guard_.reset_device(device); + } + + /// For testing only + void reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl) { + guard_.reset_device(device, impl); + } + + /// Returns the device that was set at the time the guard was constructed. + std::optional original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via reset_device. + std::optional current_device() const { + return guard_.current_device(); + } + + private: + impl::InlineOptionalDeviceGuard guard_; +}; + +// Note [Whither the DeviceGuard boilerplate] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Design note: in principle, we could avoid these wrappers using: +// +// using DeviceGuard = impl::InlineDeviceGuard; +// using OptionalDeviceGuard = +// impl::InlineOptionalDeviceGuard; +// +// But the error messages are worse, and our users can't just look at the +// header file to find out what's going on. Furthermore, for specializations +// like CUDAStreamGuard, it can be profitable to replace some interfaces with +// refined types (e.g., return CUDAStream instead of Stream). So, we eat +// the boilerplate and write out the API explicitly. + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceType.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceType.h new file mode 100644 index 0000000000000000000000000000000000000000..3847b5e2650e4100d19dc0031747769f709b92f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DeviceType.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +// If you modified DeviceType in caffe2/proto/caffe2.proto, please also sync +// your changes into torch/headeronly/core/DeviceType.h. +#include + +#include +#include + +namespace c10 { + +C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false); + +C10_API bool isValidDeviceType(DeviceType d); + +C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type); + +C10_API void register_privateuse1_backend(const std::string& backend_name); +C10_API std::string get_privateuse1_backend(bool lower_case = true); + +C10_API bool is_privateuse1_backend_registered(); + +} // namespace c10 + +namespace torch { +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::DeviceType; +} // namespace torch + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DispatchKey.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DispatchKey.h new file mode 100644 index 0000000000000000000000000000000000000000..2aa647574ccbc1112d10a5558255d9a5b625a9b2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DispatchKey.h @@ -0,0 +1,750 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// Semantically, each value of BackendComponent identifies a "backend" for our +// dispatch. Some functionalities that we may dispatch to are allowed to +// register different handlers for each backend. The BackendComponent is then +// used to figure out which backend implementation to dispatch to. + +// In implementation terms, the backend component identifies a specific "bit" in +// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom +// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to +// functionalities. When we encounter a functionality bit that is known to be +// customizable per-backend, then we also look at the lower BackendComponent +// bits and take the highest bit to determine which backend's implementation to +// use. + +// WARNING! If you add a new backend component to the end of this list, +// make sure you register it before Meta. +// Meta must be at the end so that meta key in tls triggers meta kernels. +// (But you shouldn't: private use keys should have higher precedence than all +// built-in keys) + +// If you add a new (non-privateuse) backend here, +// make sure to add an Autograd fallthrough kernel +// in aten/src/ATen/core/VariableFallbackKernel.cpp + +#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \ + _(CPU, extra) \ + _(CUDA, extra) \ + _(HIP, extra) \ + _(XLA, extra) \ + _(MPS, extra) \ + _(IPU, extra) \ + _(XPU, extra) \ + _(HPU, extra) \ + _(VE, extra) \ + _(Lazy, extra) \ + _(MTIA, extra) \ + _(MAIA, extra) \ + _(PrivateUse1, extra) \ + _(PrivateUse2, extra) \ + _(PrivateUse3, extra) \ + _(Meta, extra) + +// WARNING! If we add a new per-backend functionality key that has higher +// priority than Autograd, then make sure you update EndOfRuntimeBackendKeys + +#define C10_FORALL_FUNCTIONALITY_KEYS(_) \ + _(Dense, ) \ + _(Quantized, Quantized) \ + _(Sparse, Sparse) \ + _(SparseCsr, SparseCsr) \ + _(NestedTensor, NestedTensor) \ + _(AutogradFunctionality, Autograd) + +enum class BackendComponent : uint8_t { + + // A "backend" is colloquially used to refer to handlers for dispatch + // which actually implement the numerics of an operation in question. + // + // Due to the nature of the enum, these backends are specified in + // an ordered way, but for most backends this order is not semantically + // meaningful (e.g., it's valid to reorder these backends without changing + // semantics). The only situation when backend ordering is meaningful + // is when the backend participates in multiple dispatch with another + // backend; e.g., CPU and CUDA (cuda must have higher priority). + + // These keys don't correspond to individual kernels. + // Instead, they represent the backends that are allowed to override specific + // pieces of functionality: + // - dense kernels (e.g. DispatchKey::CPU) + // - sparse kernels (e.g. DispatchKey::SparseCPU) + // - quantized kernels (e.g. DispatchKey::QuantizedCPU) + // - autograd kernels (e.g. DispatchKey::AutogradCPU) + // We reserve space in the runtime operator table for this full cross product + // of + // [backends in this enum] x [keys below that are explicitly marked as having + // per-backend functionality] + // + // A meta tensor is a tensor without any data associated with it. (They + // have also colloquially been referred to as tensors on the "null" device). + // A meta tensor can be used to dry run operators without actually doing any + // computation, e.g., add on two meta tensors would give you another meta + // tensor with the output shape and dtype, but wouldn't actually add anything. + + InvalidBit = 0, +#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit, + C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused) +#undef DEFINE_BACKEND_COMPONENT + + // Define an alias to represent end of backend dispatch keys. + // If you add new backend keys after PrivateUse3, please also update it here. + EndOfBackendKeys = MetaBit, +}; + +// Semantically, a dispatch key identifies a possible "level" in our +// dispatch, for which a handler may be registered. Each handler corresponds +// to a type of functionality. +// +// In implementation terms, the dispatch key identifies a specific "bit" in a +// DispatchKeySet. Higher bit indexes get handled by dispatching first (because +// we "count leading zeros" when we extract the highest priority dispatch +// key.) +// +// Note [DispatchKey Classification] +// This enum actually contains several types of keys, which are explained +// in more detail further down: +// (1) non-customizable backends (e.g. FPGA) +// (2) non-customizable functionalities (e.g. Functionalize) +// (3) functionalized that are customizable per backend (e.g. Dense, Sparse, +// AutogradFunctionality) (4) per-backend instances of customizable +// functionalities (e.g. CPU, SparseCPU, AutogradCPU) (5) alias keys (e.g. +// CompositeImplicitAutograd) +// +// Of the categories above, it's important to note: +// (a) which keys are assigned individual bits in a DispatchKeySet +// (b) which keys are assigned individual slots in the runtime operator table +// ("Runtime keys") +// +// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet. +// (1), (2) and (4) all get their own dedicated slots in the runtime operator +// table. + +// See Note [DispatchKeySet Internal Representation] for more details. +// +// NOTE: Keep the list in sync with `DispatchKey` in torchgen/model.py +enum class DispatchKey : uint16_t { + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // This is not a "real" functionality, but it exists to give us a "nullopt" + // element we can return for cases when a DispatchKeySet contains no elements. + // You can think a more semantically accurate definition of DispatchKey is: + // + // using DispatchKey = std::optional + // + // and Undefined == nullopt. We didn't actually represent + // it this way because std::optional would take two + // words, when DispatchKey fits in eight bits. + + Undefined = 0, + + // Define an alias for Undefined to represent CatchAll (long term + // this will get eliminated, but for now it's convenient) + CatchAll = Undefined, + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ // + // Every value in the enum (up to EndOfFunctionalityKeys) + // corresponds to an individual "functionality" that can be dispatched to. + // This is represented in the DispatchKeySet by assigning each of these enum + // values + // to each of the remaining (64 - len(BackendComponent)) bits. + // + // Most of these functionalities have a single handler assigned to them, + // making them "runtime keys". + // That map to a single slot in the runtime operator table. + // + // A few functionalities are allowed to be customizable per backend. + // See [Note: Per-Backend Functionality Dispatch Keys] for details. + + // See [Note: Per-Backend Functionality Dispatch Keys] + Dense, + + // Below are non-extensible backends. + // These are backends that currently don't have their own overrides for + // Autograd/Sparse/Quantized kernels, + // and we therefore don't waste space in the runtime operator table allocating + // space for them. + // If any of these backends ever need to customize, e.g., Autograd, then we'll + // need to add a DispatchKey::*Bit for them. + + // TODO: put this in BackendComponents + FPGA, // Xilinx support lives out of tree at + // https://gitlab.com/pytorch-complex/vitis_kernels + + Vulkan, // TODO: put this in BackendComponents + Metal, // TODO: put this in BackendComponents + + // See [Note: Per-Backend Functionality Dispatch Keys] + Quantized, + + // This backend is to support custom RNGs; it lets you go + // to a different kernel if you pass in a generator that is not a + // traditional CPUGeneratorImpl/CUDAGeneratorImpl. To make use of this + // key: + // 1) set it as a second parameter of at::Generator constructor call in + // the user-defined PRNG class. + // 2) use it as a dispatch key while registering custom kernels + // (templatized kernels specialized for user-defined PRNG class) + // intended for out of tree use; tested by aten/src/ATen/test/rng_test.cpp + CustomRNGKeyId, + + // TODO: Make Mkldnn a functionality key, so we can give it Meta + // support + // Here are backends which specify more specialized operators + // based on the layout of the tensor. Note that the sparse backends + // are one case where ordering matters: sparse multi-dispatches with + // the corresponding dense tensors, and must be handled before them. + MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp + // NB: not to be confused with MKLDNN, which is Caffe2 only + + // See [Note: Per-Backend Functionality Dispatch Keys] + Sparse, + + SparseCsr, + + NestedTensor, + + // In some situations, it is not immediately obvious what the correct + // backend for function is, because the function in question doesn't + // have any "tensor" arguments. In this case, a BackendSelect function + // can be registered to implement the custom determination of the + // correct backend. + BackendSelect, + + Python, + + // Out-of-core key for Fake Tensor in torchdistx. + // See https://pytorch.org/torchdistx/latest/fake_tensor.html + // TODO: delete this in favor of Python-implemented fake tensor + Fake, + // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key + // is to insert code after the "autograd subsystem" runs, so this key should + // be directly after ADInplaceOrView and all of the autograd keys. + FuncTorchDynamicLayerBackMode, + + // Alias and mutation removal. + // If some backends want to opt into only alias removal or only mutation + // removal, + // we can consider adding separate keys dedicated to those individual passes. + // See Note [Functionalization Pass In Core] for details. + Functionalize, + + // The named dispatch key is set for any tensors with named dimensions. + // Although we have a dispatch key for named tensors, for historical reasons, + // this dispatch key doesn't do any of the substantive functionality for named + // tensor (though, hypothetically, it could!) At the moment, it's just + // responsible for letting us give good error messages when operations + // don't support named tensors. + // + // NB: If you ever consider moving named tensor functionality into + // this dispatch key, note that it might be necessary add another dispatch + // key that triggers before composite operators, in case a composite operator + // has named dimension propagation that doesn't match that of its + // constituent parts. + // TODO: delete this once torchdim lands in functorch + Named, + + // The Conjugate dispatch key is set for any tensors that need to perform + // conjugation + // This is implemented at a dispatch level right before any backends run + Conjugate, + + // The Negative dispatch key is set for any tensors that need to perform + // negation + // This is implemented at a dispatch level right before any backends run + Negative, + + ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp + + // Note [ADInplaceOrView key] + // ADInplaceOrView key is used by inplace or view ops to register a kernel + // that does additional setup for future autograd computation. + // + // 1. For inplace ops this kernel does version bump + // 2. For view ops this kernel does `as_view` setup where we properly setup + // DifferentiableViewMeta on the view tensors. + // + // For other ops it's fallthrough kernel since there's no extra + // work to do. + // + // Note [Dream: skip VariableType kernel when requires_grad=false] + // + // In an ideal world where we can skip VariableType kernel for inputs + // with requires_grad=false, instead of a fallthrough kernel, we'll + // register a kernel shown below to all functional ops as well: + // torch::Tensor my_functional_op(...) { + // { + // // Note for every op in VariableType, you need to go through + // // `AutoDispatchBelowADInplaceOrView` guard exactly once to add the + // // key to TLS excluded set. If you don't go through it at all, + // // inplace/view ops called through `at::` inside your backend + // // kernel will dispatch to ADInplaceOrView kernels and do a lot + // // of extra work. + // at::AutoDispatchBelowADInplaceOrView guard; + // at::redispatch::my_functional_op(...); + // } + // } + // But this work is currently blocked since it adds an extra dispatch + // for all ops and it's non-trivial overhead at model level(a few percents). + // Thus our current approach takes advantage of the fact every kernel go + // through VariableType kernel first and pulls the + // `at::AutoDispatchBelowADInplaceOrView` guard of functional ops + // up to the `VariableType` kernel. Thus we only add the extra dispatch + // to view/inplace ops to minimize its perf impact to real models. + ADInplaceOrView, + // Note [Alias Dispatch Key : Autograd] + // All backends are oblivious to autograd; autograd is handled as a + // layer which happens on top of all backends. It inspects the autograd + // metadata of all inputs, determines what autograd metadata should be + // constructed by the output, and otherwise defers to the backend to + // actually do the numeric computation. Autograd contains + // the bulk of this logic. + + // Autograd is now an alias dispatch key which by default maps to all + // backend-specific autograd keys. + // Backend-specific allow backends to override the default kernel registered + // to Autograd key as needed. + // For example, XLA wants to define autograd for einsum directly. + // Registering a custom autograd implementation at the XLA key won't work + // because we process Autograd before XLA. This key has higher priority and + // gets processed first. You generally should NOT redispatch after handling + // autograd here (since that would result in execution of the Autograd + // operator, which you're trying to skip). In AutogradXLA implementations, + // you are responsible for handling autograd yourself, or deferring to other + // operators which support autograd. + + // Currently we only have backend-specific autograd keys for CPU/CUDA/XLA and + // reserved user-defined backends. All other in-tree backends share the + // AutogradOther key. We can add specific autograd key for those backends + // upon request. + AutogradOther, + + // See [Note: Per-Backend Functionality Dispatch Keys] + AutogradFunctionality, + + // NestedTensor is an example of something that isn't a "real backend" + // (because it mostly consists of redispatching kernels) + // but it would like to override autograd functionality in C++. + // We can handle cases like this by adding an extra functionality key + // exclusively for handling autograd for NestedTensor. + // lives out of tree at + // https://github.com/pytorch/nestedtensor + AutogradNestedTensor, + + Tracer, + + // TODO: make Autocast a functionality key + // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed + // and inputs are saved for backward in the post-autocast type. + AutocastCPU, + AutocastMTIA, + AutocastMAIA, + AutocastXPU, + AutocastIPU, + AutocastHPU, + AutocastXLA, + // AutocastXLA is only being used for TPUs. XLA GPUs continue to use + // AutocastCUDA. + AutocastMPS, + AutocastCUDA, + AutocastPrivateUse1, + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // There are a number of alternative modes which may want to handle before + // autograd; for example, error checking, tracing, profiling or vmap. They + // go here. + + FuncTorchBatched, // See Note [Out-of-tree vmap+grad prototype] + + // Dispatch key for BatchedTensorImpl wrapping a nested tensor. + BatchedNestedTensor, + + FuncTorchVmapMode, // See Note [Out-of-tree vmap+grad prototype] + + // This is the dispatch key for BatchedTensorImpl, which is used to implement + // batching rules for vmap. + Batched, + + // When we are inside a vmap, all tensors dispatch on this key. + // See Note: [DispatchKey::VmapMode usage] for more details. + VmapMode, + + FuncTorchGradWrapper, // See Note [Out-of-tree vmap+grad prototype] + + // Out-of-core key for Deferred Module Initialization in torchdistx. + // See https://pytorch.org/torchdistx/latest/deferred_init.html + DeferredInit, + + // Used by Python key logic to know the set of tls on entry to the dispatcher + // This kernel assumes it is the top-most non-functorch-related DispatchKey. + // If you add a key above, make sure to update the fallback implementation for + // this. + PythonTLSSnapshot, + + // This key should be at the very top of the dispatcher + FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype] + + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptable use is within a single + // process test. Use it by creating a TensorImpl with this DispatchKey, and + // then registering operators to operate on this type id. See + // aten/src/ATen/core/dispatch/backend_fallback_test.cpp for a usage example. + TESTING_ONLY_GenericWrapper, + + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptable use is within a ingle + // process test. Use it by toggling the mode on and off via + // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators + // to operate on this type id. See + // aten/src/ATen/core/dispatch/backend_fallback_test.cpp + // for a usage example + TESTING_ONLY_GenericMode, + + // This key is used for pre-dispatch tracing in make_fx. + // It has lower priority than the PythonDispatcher key + // because we use the PythonDispatcher to intercept the key from python, + // and avoid having to implement it in C++. + PreDispatch, + + // This is a bypass that allows you to skip running the C++ dispatcher + // entirely + PythonDispatcher, + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + EndOfFunctionalityKeys, // End of functionality keys. + +// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ // +// Here are backends which you think of as traditionally specifying +// how to implement operations on some device. + +#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n, + +#define DEFINE_PER_BACKEND_KEYS(fullname, prefix) \ + StartOf##fullname##Backends, \ + C10_FORALL_BACKEND_COMPONENTS( \ + DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \ + EndOf##fullname##Backends = prefix##Meta, + + C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS) + +#undef DEFINE_PER_BACKEND_KEYS +#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND + + EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends, + + // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // Note [Alias Dispatch Keys] + // Alias dispatch keys are synthetic dispatch keys which map to multiple + // runtime dispatch keys. Alisa keys have precedence, but they are always + // lower precedence than runtime keys. You can register a kernel to an + // alias key, the kernel might be populated to the mapped runtime keys + // during dispatch table computation. + // If a runtime dispatch key has multiple kernels from alias keys, which + // kernel wins is done based on the precedence of alias keys (but runtime + // keys always have precedence over alias keys). + // Alias keys won't be directly called during runtime. + + // See Note [Alias Dispatch Key : Autograd] + Autograd, + CompositeImplicitAutograd, // registered at + // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp + + // Note: The alias keyset for FuncTorchBatchedDecomposition is disjoint from + // all + // other alias keysets + // and so precedence order doesn't matter + FuncTorchBatchedDecomposition, // registered at + // build/aten/src/ATen/RegisterFuncTorchBatchedDecomposition.cpp + // Note: The alias keyset for CompositeImplicitAutogradNestedTensor is + // disjoint from all other alias keysets + CompositeImplicitAutogradNestedTensor, // registered at + // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp + CompositeExplicitAutograd, // registered at + // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp + // See Note [CompositeExplicitAutogradNonFunctional Key] + CompositeExplicitAutogradNonFunctional, // registered at + // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp + + // Define an alias key to represent end of alias dispatch keys. + // If you add new alias keys after Autograd, please also update it here. + StartOfAliasKeys = Autograd, + EndOfAliasKeys = CompositeExplicitAutogradNonFunctional, // + + // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // The aliases exist for backwards compatibility reasons, they shouldn't + // be used + CPUTensorId = CPU, + CUDATensorId = CUDA, + DefaultBackend = CompositeExplicitAutograd, + PrivateUse1_PreAutograd = AutogradPrivateUse1, + PrivateUse2_PreAutograd = AutogradPrivateUse2, + PrivateUse3_PreAutograd = AutogradPrivateUse3, + Autocast = AutocastCUDA, +}; + +// Note [Private use DispatchKey] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Private use tensor IDs are preallocated tensor type IDs for use in user +// applications. Similar to private use fields in HTTP, they can be used +// by end users for experimental or private applications, without needing +// to "standardize" the tensor ID (which would be done by submitting a PR +// to PyTorch to add your type ID). +// +// Private use tensor IDs are appropriate to use if you want to experiment +// with adding a new tensor type (without having to patch PyTorch first) or +// have a private, non-distributed application that needs to make use of a +// new tensor type. Private use tensor IDs are NOT appropriate to use for +// libraries intended to be distributed to further users: please contact +// the PyTorch developers to get a type ID registered in this case. +// +// We provide two classes of private user tensor id: regular DispatchKeys +// and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend" +// DispatchKeys; if you were adding support for a new type of accelerator, you +// would use a backend DispatchKey, and ideally automatically reuse +// AutogradOther definitions already defined in PyTorch. AutogradPrivateUse +// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for +// tensors that compose multiple internal tensors, and for cases when the +// built-in autograd formulas for operators are not appropriate. + +static_assert( + (static_cast(BackendComponent::EndOfBackendKeys) + + static_cast(DispatchKey::EndOfFunctionalityKeys)) <= 64, + "The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)" + " both map to backend and functionality bits" + " into a 64-bit bitmask; you must have less than 64 total entries between them"); + +// Check if a DispatchKey is an alias mapping to other runtime keys. +constexpr bool isAliasDispatchKey(DispatchKey k) { + return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys; +} + +// [Note: Per-Backend Functionality Dispatch Keys] +// Check if a DispatchKey is a per-backend functionality key +// Any functionalities that can be customized per-backend should be added here. +// These keys correspond to functionalities that can be customized individually +// per backend. While they only take up one bit in the `DispatchKeySet` bitset, +// they map to (# backends) slots in the operator table. +// Each of these keys also has a separate set of "runtime keys" in the dispatch +// key enum, per backend, which *do* map to the individual operator table slots. +// For example, the "Sparse" key maps to an individual bit in the +// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual +// slots in the runtime operator table. + +constexpr bool isPerBackendFunctionalityKey(DispatchKey k) { + if (k == DispatchKey::Dense || k == DispatchKey::Quantized || + k == DispatchKey::Sparse || k == DispatchKey::SparseCsr || + k == DispatchKey::AutogradFunctionality || + k == DispatchKey::NestedTensor) { + return true; + } else { + return false; + } +} + +// Note that this includes Undefined in the total count. +// BUT EndOfFunctionalityKeys is its own (placeholder) key. +// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3. +// In the above example, there are 3 total functionality keys. +constexpr uint8_t num_functionality_keys = + static_cast(DispatchKey::EndOfFunctionalityKeys); + +constexpr uint8_t num_backends = + static_cast(BackendComponent::EndOfBackendKeys); + +// Note [No More Than 16 Backends] +// Search for this note to find places in the code where the "no more than 16 +// backends" invariant is baked in. +static_assert( + static_cast(BackendComponent::EndOfBackendKeys) <= 16, + "BackendComponent currently only supports <= 16 backends. If we really need to extend this, \ +there are a few places where this invariant is baked in"); + +constexpr uint8_t numPerBackendFunctionalityKeys() { + uint8_t count = 0; + for (uint8_t k = 0; k <= num_functionality_keys; ++k) { + if (isPerBackendFunctionalityKey(static_cast(k))) + ++count; + } + return count; +} + +#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) +// See [Note: Trimmed Mobile Dispatch Keys] +constexpr uint16_t num_runtime_entries = 8; +#else +constexpr uint16_t num_runtime_entries = num_functionality_keys + + (numPerBackendFunctionalityKeys() * (num_backends - 1)); +#endif + +// See Note [No More Than 16 Backends] +constexpr uint16_t full_backend_mask = + (static_cast(1) << num_backends) - 1; + +C10_API const char* toString(DispatchKey /*t*/); +C10_API const char* toString(BackendComponent /*t*/); +C10_API std::ostream& operator<<(std::ostream& /*str*/, DispatchKey /*rhs*/); +C10_API std::ostream& operator<<( + std::ostream& /*str*/, + BackendComponent /*rhs*/); + +C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k); + +// Parses a string into a dispatch key. +// If the string cannot be correctly parsed, throws an exception. +C10_API c10::DispatchKey parseDispatchKey(const std::string& k); + +// These are some convenience identifiers for dispatch keys which are +// shorter to type than their long counterparts. Note that some of these +// dispatch keys directly correspond to DeviceType; and most APIs that +// accept DispatchKey also accept DeviceType; e.g., +// torch::dispatch(torch::kCPU, ...) is also valid. +constexpr DispatchKey kAutograd = DispatchKey::Autograd; + +// See Note [The Ordering of Per-Backend Dispatch Keys Matters!] +// This function relies on the invariant that the dispatch keys between +// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend +// in the same order as `BackendComponent`. +constexpr BackendComponent toBackendComponent(DispatchKey k) { + if (k >= DispatchKey::StartOfDenseBackends && + k <= DispatchKey::EndOfDenseBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfDenseBackends)); + } else if ( + k >= DispatchKey::StartOfQuantizedBackends && + k <= DispatchKey::EndOfQuantizedBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfQuantizedBackends)); + } else if ( + k >= DispatchKey::StartOfSparseBackends && + k <= DispatchKey::EndOfSparseBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfSparseBackends)); + } else if ( + k >= DispatchKey::StartOfSparseCsrBackends && + k <= DispatchKey::EndOfSparseCsrBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfSparseCsrBackends)); + } else if ( + k >= DispatchKey::StartOfNestedTensorBackends && + k <= DispatchKey::EndOfNestedTensorBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfNestedTensorBackends)); + } else if ( + k >= DispatchKey::StartOfAutogradFunctionalityBackends && + k <= DispatchKey::EndOfAutogradFunctionalityBackends) { + return static_cast( + static_cast(k) - + static_cast( + DispatchKey::StartOfAutogradFunctionalityBackends)); + } else { + return BackendComponent::InvalidBit; + } +} + +constexpr DispatchKey toFunctionalityKey(DispatchKey k) { + if (k <= DispatchKey::EndOfFunctionalityKeys) { + return k; + } else if (k <= DispatchKey::EndOfDenseBackends) { + return DispatchKey::Dense; + } else if (k <= DispatchKey::EndOfQuantizedBackends) { + return DispatchKey::Quantized; + } else if (k <= DispatchKey::EndOfSparseBackends) { + return DispatchKey::Sparse; + } else if (k <= DispatchKey::EndOfSparseCsrBackends) { + return DispatchKey::SparseCsr; + } else if (k <= DispatchKey::EndOfNestedTensorBackends) { + return DispatchKey::NestedTensor; + } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) { + return DispatchKey::AutogradFunctionality; + } else { + return DispatchKey::Undefined; + } +} + +BackendComponent toBackendComponent(DeviceType device_type); + +// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns +// DispatchKey::CUDA. +// See Note [The Ordering of Per-Backend Dispatch Keys Matters!] +// This function relies on the invariant that the dispatch keys between +// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend +// in the same order as `BackendComponent`. +constexpr DispatchKey toRuntimePerBackendFunctionalityKey( + DispatchKey functionality_k, + BackendComponent backend_k) { + if (functionality_k == DispatchKey::Dense) { + return static_cast( + static_cast(DispatchKey::StartOfDenseBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::Sparse) { + return static_cast( + static_cast(DispatchKey::StartOfSparseBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::SparseCsr) { + return static_cast( + static_cast(DispatchKey::StartOfSparseCsrBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::Quantized) { + return static_cast( + static_cast(DispatchKey::StartOfQuantizedBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::NestedTensor) { + return static_cast( + static_cast(DispatchKey::StartOfNestedTensorBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::AutogradFunctionality) { + return static_cast( + static_cast( + DispatchKey::StartOfAutogradFunctionalityBackends) + + static_cast(backend_k)); + } + return DispatchKey::Undefined; +} + +} // namespace c10 + +namespace torch { +// Expose the constant, but not the TYPE (DispatchKey is an implementation +// detail!) +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::kAutograd; +} // namespace torch + +// NB: You really shouldn't use this instance; this enum is guaranteed +// to be pretty small so a regular array should be acceptable. +namespace std { +template <> +struct hash { + typedef size_t result_type; + typedef c10::DispatchKey argument_type; + + size_t operator()(c10::DispatchKey x) const { + return static_cast(x); + } +}; +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DispatchKeySet.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DispatchKeySet.h new file mode 100644 index 0000000000000000000000000000000000000000..ec3aff4e0c2295b2490cd29d30aa1117e6bb0441 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DispatchKeySet.h @@ -0,0 +1,977 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + +namespace c10 { + +struct FunctionalityOffsetAndMask { + // empty constructor shouldn't be used; only needed to initialize + // the array before populating it. + FunctionalityOffsetAndMask() = default; + FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask) + : offset(offset), mask(mask) {} + // This needs to big enough to cover the size of the operator table. + uint16_t offset{}; + // See Note [No More Than 16 Backends] + // This mask needs to be big enough to mask all of the backend bits. + // We probably don't ever want to have more than 16 backend bits, so uint16_t + // should be enough. + uint16_t mask{}; +}; +static_assert( + c10::num_runtime_entries < 65536, + "The dispatcher currently only supports up to 2^16 runtime entries"); + +C10_API std::array +initializeFunctionalityOffsetsAndMasks(); + +C10_ALWAYS_INLINE static const std:: + array& + offsetsAndMasks() { + static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks(); + return offsets_and_masks_; +} + +// A representation of a set of DispatchKeys. A DispatchKeySet contains both +// "functionality" bits and "backend bits", and every tensor holds its own +// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the +// keyset on every input tensor, or’ing them together, and dispatching to a +// specific piece of functionality. The functionality bits are *ordered*. When +// multiple functionality bits are set, we use the highest priority +// functionality. Similarly, multiple backend bits can theoretically be set if +// you call an operator with multiple tensors from difference devices (e.g. CPU +// and CUDA), although support for mixed device dispatch is limited (the only +// kernels that gracefully handle mixed device inputs for now are cuda kernels +// that take in a scalar cpu tensor). + +// A representation of a set of DispatchKeys. A tensor may have multiple +// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the +// DispatchKeySet specifies what type ids apply. The internal representation is +// as a 64-bit bit set (this means only 64 tensor type ids are supported). +// +// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like +// "what is the highest priority DispatchKey in the set"? (The set itself is +// not ordered; two sets with the same ids will always have the ids ordered in +// the same way.) +// +// Note [DispatchKeySet Internal Representation] +// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects +// that get passed around at runtime. +// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset +// and individual dispatch keys. +// +// First: why do we have this distinction, and why not map every dispatch key +// directly to a bit? This is mostly because we have several types of +// functionalities that different backends would like to customize. For example, +// we have: +// - "Dense": CPU, CUDA, XLA, ... (~12 keys) +// - "Sparse": SparseCPU, SparseCUDA, ... +// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ... +// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ... +// - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ... +// The problem is that total number of keys grows quadratically with [# +// backends] x [# functionalities], making it very difficult to map each key +// directly to a bit in a bitset without dramatically increasing the size of the +// bitset over time. +// +// The two enums (BackendComponent and DispatchKey) can be divided roughly into +// 5 categories. +// +// (1) "Building block" keys +// (a) backends: Everything in the BackendComponent enum (e.g. CPUBit, +// CUDABit) (b) functionalities: (per-backend) functionality-bit DispatchKeys +// (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense) +// (2) "Runtime" keys +// (a) "non-customizable backends" (e.g. FPGA) +// (b) "non-customizable functionalities" (e.g. Functionalize) +// (c) "per-backend instances of customizable functionalities" (e.g. CPU, +// SparseCPU, AutogradCPU) +// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys]) +// +// (1) Building block keys always correspond to individual bits in a +// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual +// runtime keys. e.g. +// auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit, +// DispatchKey::Dense}); +// // The keyset has the runtime dense-cpu key. +// dense_cpu_ks.has(DispatchKey::CPU); +// // And it contains the building block keys too. +// dense_cpu_ks.has(DispatchKey::CPUBit); +// dense_cpu_ks.has(DispatchKey::Dense); +// +// Not every backend and not every functionality counts as a "building block +// key". This is mostly to give us more levers to pull in the design space. +// Backend keys and functionality keys that count as "building blocks" will +// contribute to a full cross product of functionality that can be overridden. +// +// For example, right now we have at least 12 "backend" building +// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality" +// building blocks (Dense, Sparse, SparseCsr, Quantized, +// AutogradFunctionality, ...). These keys together allow every +// dispatcher operator to be customized in up to 12*4 different +// ways. Each of those requires a slot in the operator table of every +// dispatcher operator. Not every piece of functionality necessarily +// needs to be customizable per-backend, and not every backend +// necessarily needs to be able to customize every type of +// functionality. +// +// +// (2) Every runtime key corresponds directly to a slot in an operator's runtime +// dispatch table, and you can directly register kernels to a runtime dispatch +// key. +// +// For per-backend functionalities like "Dense" or "AutogradFunctionality", +// you can think of the corresponding runtime dispatch keys as "instances" of +// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all +// runtime instances of the "Dense" building block key. + +// (2a) and (2b) are represented identically in the DispatchKeySet logic: +// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT +// customizable per backend. +// In order to do so, we'd need to promote it to a per-backend functionality +// "building block" key. +// - non-customizable backends (e.g. FPGA) can NOT customize existing +// functionality like Sparse, Autograd, etc. +// In order to do so, we'd need to promote it to a backend "building block" +// key. +// +// In both cases, these keys directly correspond to runtime slots in the +// operator table. +// +// +// (3) "Alias" keys +// See Note [Alias Dispatch Keys] +// +// Final note: for anyone making future changes to the Dispatcher + +// DispatchKeySet internals, there's a closed PR with a basic +// python-implementation of the Dispatcher that might be useful in quickly +// testing out and validating changes. See it at +// https://github.com/pytorch/pytorch/pull/68743 + +// An undefined tensor is one with an empty tensor type set. +class DispatchKeySet final { + public: + enum Full { FULL }; + enum FullAfter { FULL_AFTER }; + enum Raw { RAW }; + + // NB: default constructor representation as zero is MANDATORY as + // use of DispatchKeySet in TLS requires this. + constexpr DispatchKeySet() = default; + + constexpr DispatchKeySet(Full /*unused*/) + : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {} + + constexpr DispatchKeySet(FullAfter /*unused*/, DispatchKey t) + // LSB after t are OK, but not t itself. + // "functionalities" have a notion of ordering (e.g. Autograd > Sparse > + // Quantized > Dense). But backends don't really have an ordering. + // Therefore, we're enforcing that FullAfter can only be used on + // "functionality" keys. + : repr_( + (1ULL + << (num_backends + static_cast(toFunctionalityKey(t)) - + 1)) - + 1) { + *this = add(DispatchKey::PythonDispatcher); + } + + // Public version of DispatchKeySet(uint64_t) API; external users + // must be explicit when they do this! + constexpr DispatchKeySet(Raw /*unused*/, uint64_t x) : repr_(x) {} + + constexpr explicit DispatchKeySet(BackendComponent k) { + if (k == BackendComponent::InvalidBit) { + repr_ = 0; + } else { + repr_ = 1ULL << (static_cast(k) - 1); + } + } + + constexpr explicit DispatchKeySet(DispatchKey k) { + // NOLINTNEXTLINE(bugprone-branch-clone) + if (k == DispatchKey::Undefined) { + // Case 1: handle Undefined specifically + repr_ = 0; + } else if (k <= DispatchKey::EndOfFunctionalityKeys) { + // Case 2: handle "functionality-only" keys + // These keys have a functionality bit set, but no backend bits + // These can technically be either: + // - valid runtime keys (e.g. DispatchKey::AutogradOther, + // DispatchKey::FuncTorchBatched, etc) + // - "building block" keys that aren't actual runtime keys (e.g. + // DispatchKey::Dense or Sparse) + uint64_t functionality_val = 1ULL + << (num_backends + static_cast(k) - 1); + repr_ = functionality_val; + } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) { + // Case 3: "runtime" keys that have a functionality bit AND a backend bit. + // First compute which bit to flip for the functionality. + auto functionality_k = toFunctionalityKey(k); + // The - 1 is because Undefined is technically a "functionality" that + // doesn't show up in the bitset. So e.g. Dense is technically the second + // functionality, but the lowest functionality bit. + uint64_t functionality_val = 1ULL + << (num_backends + static_cast(functionality_k) - 1); + + // then compute which bit to flip for the backend + // Case 4a: handle the runtime instances of "per-backend functionality" + // keys For example, given DispatchKey::CPU, we should set: + // - the Dense functionality bit + // - the CPUBit backend bit + // first compute which bit to flip for the backend + auto backend_k = toBackendComponent(k); + uint64_t backend_val = backend_k == BackendComponent::InvalidBit + ? 0 + : 1ULL << (static_cast(backend_k) - 1); + repr_ = functionality_val + backend_val; + } else { + // At this point, we should have covered every case except for alias keys. + // Technically it would be possible to add alias dispatch keys to a + // DispatchKeySet, but the semantics are a little confusing and this + // currently isn't needed anywhere. + repr_ = 0; + } + } + + constexpr uint64_t keys_to_repr(std::initializer_list ks) { + uint64_t repr = 0; + for (auto k : ks) { + repr |= DispatchKeySet(k).repr_; + } + return repr; + } + + constexpr uint64_t backend_bits_to_repr( + std::initializer_list ks) { + uint64_t repr = 0; + for (auto k : ks) { + repr |= DispatchKeySet(k).repr_; + } + return repr; + } + + explicit constexpr DispatchKeySet(std::initializer_list ks) + : repr_(keys_to_repr(ks)) {} + + explicit constexpr DispatchKeySet(std::initializer_list ks) + // Note: for some reason, putting this logic directly in the constructor + // appears to fail to compile on CUDA 10.1. + // See an example internal failure at + // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr + : repr_(backend_bits_to_repr(ks)) {} + + // Test if a DispatchKey is in the set + inline bool has(DispatchKey t) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); + return has_all(DispatchKeySet(t)); + } + constexpr bool has_backend(BackendComponent t) const { + return has_all(DispatchKeySet(t)); + } + + // Test if a DispatchKey is in the set + // Given a DispatchKeySet of functionality keys and (potentially) backend + // keys, tests if all of them are in the current set. + constexpr bool has_all(DispatchKeySet ks) const { + return static_cast((repr_ & ks.repr_) == ks.repr_); + } + + // Given a DispatchKeySet of functionality keys and (potentially) backend + // keys, tests if any of them are in the current set. This could technically + // be pretty easily implemented using has(). It is strictly a perf + // optimization though. There are many places in the code base where we want + // to test for multiple functionality keys together. HOWEVER, runtime + // per-backend functionality keys aren't allowed to be used with this + // function, because you can end up with weird results. e.g. + // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU)) + // would return true. + inline bool has_any(DispatchKeySet ks) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + // Either there are no backend bits in the input keyset + ((ks.repr_ & full_backend_mask) == 0) || + // or there are no per-backend-functionality bits + // See [Note: Per-Backend Functionality Dispatch Keys] + ((ks & + DispatchKeySet({ + DispatchKey::Dense, + DispatchKey::Quantized, + DispatchKey::Sparse, + DispatchKey::SparseCsr, + DispatchKey::AutogradFunctionality, + }) + .repr_) == 0)); + return static_cast((repr_ & ks.repr_) != 0); + } + // Test if DispatchKeySet is a superset of ks. + bool isSupersetOf(DispatchKeySet ks) const { + return (repr_ & ks.repr_) == ks.repr_; + } + // Perform set union + constexpr DispatchKeySet operator|(DispatchKeySet other) const { + return DispatchKeySet(repr_ | other.repr_); + } + // Perform set intersection + constexpr DispatchKeySet operator&(DispatchKeySet other) const { + return DispatchKeySet(repr_ & other.repr_); + } + // Compute the set difference self - other, + // but ONLY for the functionality keys. + // Any backend bits set on self will remain unchanged. + // See Note [Removing keys from DispatchKeySet Only Affects Functionality + // Keys] + constexpr DispatchKeySet operator-(DispatchKeySet other) const { + return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_)); + } + + // Compute self ^ other + constexpr DispatchKeySet operator^(DispatchKeySet other) const { + return DispatchKeySet(repr_ ^ other.repr_); + } + bool operator==(DispatchKeySet other) const { + return repr_ == other.repr_; + } + bool operator!=(DispatchKeySet other) const { + return repr_ != other.repr_; + } + // Add a DispatchKey to the DispatchKey set. Does NOT mutate, + // returns the extended DispatchKeySet! + [[nodiscard]] constexpr DispatchKeySet add(DispatchKey t) const { + return *this | DispatchKeySet(t); + } + [[nodiscard]] constexpr DispatchKeySet add(DispatchKeySet ks) const { + return *this | ks; + } + + // Remove a DispatchKey from the DispatchKey set. + // This is generally not an operation you should be doing + // (it's used to implement the printing overload, operator<<) + // + // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys] + // Only functionality bits are allowed to be removed from a keyset. + // For now, we're only allowing removal of "functionality bits" from the + // keyset, which is specifically needed by the fallthrough key calculation + // logic. Why is removing backend bits problematic? Consider this example: + // + // DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA, + // DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA) + // DispatchKeySet([DispatchKey.CPU, + // DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA) + // + // What do we want to happen? + // Technically, we'd like it to be true that after removal, + // the first keyset still has the CUDA dispatch key while the second doesn't. + // Unfortunately there's no way to represent that, because the two keysets are + // represented the same way internally: functionality bits: Autograd, Dense + // backend bits: CPU, CUDA + // + // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd" + // bit from the bitset. + [[nodiscard]] constexpr DispatchKeySet remove(DispatchKey t) const { + return DispatchKeySet( + repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask)); + } + // You're allowed to remove a backend bit from a DispatchKeySet, + // but you have to be explicit about it (remove_backend() instead of + // remove()). + constexpr DispatchKeySet remove_backend(BackendComponent b) const { + return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_)); + } + // Is the set empty? (AKA undefined tensor) + bool empty() const { + return repr_ == 0; + } + uint64_t raw_repr() const { + return repr_; + } + + static DispatchKeySet from_raw_repr(uint64_t x) { + return DispatchKeySet(RAW, x); + } + + DispatchKey highestFunctionalityKey() const { + auto functionality_idx = indexOfHighestBit(); + // This means that none of the functionality bits were set. + if (functionality_idx < num_backends) + return DispatchKey::Undefined; + // The first num_backend bits in the keyset don't correspond to real + // dispatch keys. + return static_cast(functionality_idx - num_backends); + } + + // This is similar like toBackendComponent(DispatchKey), but less restrictive. + // toBackendComponent() errors out if the key that it was passed has no + // backend bits, which is useful for error checking. We need a version of that + // here that can also handle "fake" backends like FPGA, because they need to + // map to the AutogradOther key. For those backends, we return + // BackendComponent::InvalidBit. + BackendComponent highestBackendKey() const { + // mask to mask out functionality bits + auto backend_idx = + DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit(); + // all zeros across the backend bits means that no backend bits are set. + if (backend_idx == 0) + return BackendComponent::InvalidBit; + return static_cast(backend_idx); + } + + // returns the DispatchKey of highest priority in the set. + DispatchKey highestPriorityTypeId() const { + auto functionality_k = highestFunctionalityKey(); + if (isPerBackendFunctionalityKey(functionality_k)) { + return toRuntimePerBackendFunctionalityKey( + functionality_k, highestBackendKey()); + } + return functionality_k; + } + + // Returns the index of the most-significant bit in the keyset. + // This is used to as part of the calculation into the operator table to get: + // - the highest "functionality" bit in the keyset. + // - the highest "backend" bit in the keyset. + uint8_t indexOfHighestBit() const { + return 64 - llvm::countLeadingZeros(repr_); + } + +#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) + // [Note: Trimmed Mobile Dispatch Keys] + /** + * The method below maps the dispatch key in the enum DispatchKey to an + * integer index in the dispatchTable_ array in OperatorEntry. The array + * is trimmed for mobile to reduce peak memory usage since it's + * unnecessary to reserve additional space for dispatch keys that will + * never be used on mobile. + */ + int getDispatchTableIndexForDispatchKeySet() const { + auto dk = highestPriorityTypeId(); + switch (dk) { + case DispatchKey::Undefined: + return 0; + case DispatchKey::CPU: + return 1; + case DispatchKey::QuantizedCPU: + return 2; + case DispatchKey::SparseCPU: + return 3; + case DispatchKey::BackendSelect: + return 4; + case DispatchKey::ADInplaceOrView: + return 5; + case DispatchKey::AutogradOther: + return 6; + case DispatchKey::AutogradCPU: + return 7; + default: + return -1; + } + } +#else + // returns the index in the operator table of highest priority key in the the + // keyset Note that we could in theory implement this using + // highestPriorityTypeId(), but this code is very hotpath and we can do it + // faster without it. + int getDispatchTableIndexForDispatchKeySet() const { + auto functionality_idx = + DispatchKeySet(repr_ >> num_backends).indexOfHighestBit(); + auto offset_and_mask = offsetsAndMasks()[functionality_idx]; + // Mask the functionality bits out first, then right-shift by 1. + // right-shifting by 1 because everything is zero-indexed. + // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should + // give us an offset of 1, etc. + auto backend_idx = + DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit(); + return offset_and_mask.offset + backend_idx; + } +#endif + + // returns the "index" of the highest priority backend in the keyset. + // This is pretty similar to getBackendKey(), but: + // - It's hotpath code (part of the runtime bitset calculation) + // - I's returns an integer index, not an enum value + // - Everything is shifted to the right by 1. + // BackendComponent::InvalidBit is technically the lowest enum value, + // but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2, + // etc. + uint64_t getBackendIndex() const { + return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit(); + } + + private: + constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {} + uint64_t repr_ = 0; + + public: + // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys + // in the set. The iterator is only invalidated by the destruction of the + // underlying DispatchKeySet as the iterator stores a pointer to the raw + // representation of the DispatchKeySet. Note: When we encounter a per-backend + // functionality (e.g. Dense or Sparse), we will iterate through EVERY backend + // in the keyset, for that functionality. For example, if the next + // functionality key to iterate over is Autograd, and the backend bits in the + // keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit], + // then the next two keys we return will be DispatchKey::AutogradCPU, + // DispatchKey::AutogradCUDA (CPU first because it has lower precedence than + // CUDA in DispatchKey.h). + class iterator { + public: + using self_type = iterator; + using iterator_category = std::input_iterator_tag; + using value_type = DispatchKey; + using difference_type = ptrdiff_t; + using reference = value_type&; + using pointer = value_type*; + // final mask value should mask out the entire keyset + static const uint8_t end_iter_mask_val = + num_backends + num_functionality_keys; + // final key value should be the last DispatchKey + static const uint8_t end_iter_key_val = num_functionality_keys; + + // current_dispatchkey_idx_ will iterate through all functionality bits. + // current_backendcomponent_idx_ will iterate through all backend bits. + explicit iterator( + const uint64_t* data_ptr, + uint8_t next_functionality = num_backends, + uint8_t next_backend = 0) + : data_ptr_(data_ptr), + next_functionality_(next_functionality), + next_backend_(next_backend), + // These are in an invalid state at construction time, and set by the + // first increment call + current_dispatchkey_idx_(end_iter_key_val), + current_backendcomponent_idx_(end_iter_key_val) { + // Go to the first key in the set + TORCH_INTERNAL_ASSERT( + next_functionality_ >= num_backends, + "num_backends=", + static_cast(num_backends), + "next_functionality_=", + static_cast(next_functionality_)); + ++(*this); + } + + C10_API self_type& operator++(); + + self_type operator++(int) { + self_type previous_iterator = *this; + ++(*this); + return previous_iterator; + } + + bool operator==(const self_type& rhs) const { + return next_functionality_ == rhs.next_functionality_ && + current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ && + next_backend_ == rhs.next_backend_ && + current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_; + } + bool operator!=(const self_type& rhs) const { + return next_functionality_ != rhs.next_functionality_ || + current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ || + next_backend_ != rhs.next_backend_ || + current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_; + } + DispatchKey operator*() const { + auto functionality_key = + static_cast(current_dispatchkey_idx_); + if (isPerBackendFunctionalityKey(functionality_key)) { + auto next_key = toRuntimePerBackendFunctionalityKey( + functionality_key, + static_cast(current_backendcomponent_idx_)); + // We expect all of the Dense, Sparse, Quantized, and Autograd keys to + // be ordered the same way with respect to their backends + TORCH_INTERNAL_ASSERT( + toBackendComponent(next_key) == + static_cast(current_backendcomponent_idx_), + "Tried to map functionality key ", + toString(functionality_key), + " and backend bit ", + toString( + static_cast(current_backendcomponent_idx_)), + " to a runtime key, but ended up with ", + toString(next_key), + ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.", + " Please double check that enum for inconsistencies."); + return next_key; + } else { + return functionality_key; + } + } + + private: + const uint64_t* data_ptr_; + uint8_t next_functionality_; + uint8_t next_backend_; + uint8_t current_dispatchkey_idx_; + uint8_t current_backendcomponent_idx_; + }; + + public: + // Returns iterator to the first key in the set. If no keys are in the + // set, then will return the end iterator. + iterator begin() const { + return iterator(&repr_); + } + + // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat + // this as the end iterator. + iterator end() const { + return iterator(&repr_, iterator::end_iter_mask_val); + } +}; + +C10_API std::string toString(DispatchKeySet /*ts*/); +C10_API std::ostream& operator<<(std::ostream& /*os*/, DispatchKeySet /*ts*/); + +inline int getDispatchTableIndexForDispatchKey(DispatchKey k) { + return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet(); +} + +// Alias key DispatchKey::Autograd maps to +// (autograd_dispatch_keyset x full_backend_mask) +// NB: keys in this set also get associated with CompositeImplicitAutograd +// +// Note [autograd_dispatch_keyset Does Not Include Backend Bits] +// We don't want to include any backend bits (BackendComponent::CPUBit, etc) +// directly in autograd_dispatch_keyset. +// Why? keysets like autograd_dispatch_keyset are commonly used to remove +// autograd keys from a DispatchKeySet throughout the code base. However, you +// are only allowed to remove functionality bits from a keyset, not backend +// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality +// Keys] for details. To be consistent and avoid confusion, we're explicitly +// setting up autograd_dispatch_keyset to not have any backend bits. +constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ + DispatchKey::AutogradFunctionality, + DispatchKey::AutogradOther, + DispatchKey::AutogradNestedTensor, +}); + +constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ + DispatchKey::AutocastCPU, + DispatchKey::AutocastMPS, + DispatchKey::AutocastCUDA, + DispatchKey::AutocastXPU, + DispatchKey::AutocastIPU, + DispatchKey::AutocastHPU, + DispatchKey::AutocastXLA, + DispatchKey::AutocastPrivateUse1, + DispatchKey::AutocastMTIA, + DispatchKey::AutocastMAIA, +}); + +// See Note [TLS Initialization] +constexpr DispatchKeySet default_included_set = DispatchKeySet({ + DispatchKey::BackendSelect, + DispatchKey::ADInplaceOrView, +}); + +constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ + DispatchKey::AutocastCPU, + DispatchKey::AutocastMPS, + DispatchKey::AutocastCUDA, + DispatchKey::AutocastXPU, + DispatchKey::AutocastIPU, + DispatchKey::AutocastHPU, + DispatchKey::AutocastXLA, + DispatchKey::AutocastPrivateUse1, + DispatchKey::AutocastMTIA, + DispatchKey::AutocastMAIA, +}); + +constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView = + autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView); + +constexpr DispatchKeySet python_ks = DispatchKeySet({ + DispatchKey::Python, + DispatchKey::PythonTLSSnapshot, +}); + +constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse); + +constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr); + +constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU); + +// backend dispatch keys that map to DispatchKey::AutogradOther +// NB: keys in this set also get associated with CompositeImplicitAutograd +constexpr DispatchKeySet autogradother_backends = + DispatchKeySet( + // HIP and VE aren't in this list: they now have their own backend bits + // which means that they can now have their own Autograd keys. + // Technically, HIP will now redispatch to its own custom AutogradHIP + // slot in the runtime table. + {DispatchKey::FPGA, + DispatchKey::Vulkan, + DispatchKey::Metal, + DispatchKey::CustomRNGKeyId, + DispatchKey::MkldnnCPU, + // Sparse and Quantized backends also live here. + DispatchKey::Sparse, + DispatchKey::SparseCsr, + DispatchKey::Quantized}) + // Including the backend bits because this keyset is used during op + // registration, which requires looping over all runtime autogradother + // backend keys. + | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); + +// The set of dispatch keys that come after autograd +// n.b. this relies on the fact that AutogradOther is currently the lowest +// Autograd key +constexpr DispatchKeySet after_autograd_keyset = + DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther); + +// The set of dispatch keys that come after ADInplaceOrView +constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet( + DispatchKeySet::FULL_AFTER, + c10::DispatchKey::ADInplaceOrView); + +// The set of dispatch keys that come after Functionalize +constexpr DispatchKeySet after_func_keyset = + DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize) + .remove( + // NOTE: we also need to remove ADInplaceOrView from the keyset when + // redispatching after the func kernels. This is because we're not + // calling the same op; we originally called an inplace op, and now + // we aren't. The original key calculation figured out which keys + // were Fallthrough based on the inplace op. That means that it did + // not include the ADInPlaceOrView kernel as a fallthrough key. + // However, we WANT the ADInPlaceOrView kernel to be ignored now + // that we're calling an out-of-place op. Re-invoking + // Dispatcher::call would re-run the Fallthrough key calculation and + // get us that, But at::redispatch is more performant. We can get + // away with it by explicitly removing the key here. + c10::DispatchKey::ADInplaceOrView); + +constexpr DispatchKeySet backend_bitset_mask = + DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1); + +constexpr auto inplace_or_view_ks = + DispatchKeySet(DispatchKey::ADInplaceOrView); +constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU); +constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU); +constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA); +constexpr auto autograd_maia_ks = DispatchKeySet(DispatchKey::AutogradMAIA); +constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU); +constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA); +constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA); +constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy); +constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta); +constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS); +constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU); +constexpr auto autograd_privateuse1_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse1); +constexpr auto autograd_privateuse2_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse2); +constexpr auto autograd_privateuse3_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse3); +constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther); +constexpr auto autograd_nested = + DispatchKeySet(DispatchKey::AutogradNestedTensor); +// keyset corresponding to functorch keys that have their own dedicated +// TensorImpl subclass. +constexpr auto functorch_transforms_ks = DispatchKeySet( + {DispatchKey::FuncTorchBatched, + DispatchKey::FuncTorchVmapMode, + DispatchKey::Batched, + DispatchKey::VmapMode, + DispatchKey::FuncTorchGradWrapper}); + +constexpr auto functorch_batched_ks = + DispatchKeySet({DispatchKey::FuncTorchBatched}); + +// This keyset has: +// (1) the functionality bits corresponding to backends (dense, sparse, +// quantized) (2) all of the backend bits set +constexpr DispatchKeySet backend_functionality_keys = + DispatchKeySet({ + DispatchKey::Dense, + DispatchKey::Quantized, + DispatchKey::Sparse, + DispatchKey::SparseCsr, + }) | + DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); + +struct OpTableOffsetAndMask { + uint16_t offset; + uint16_t backend_mask; +}; + +static_assert( + num_backends <= 16, + "Right now we expect the number of backends not to exceed 16. In the (unlikely) event" + " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too."); + +// true if t is a backend dispatch key +C10_API bool isBackendDispatchKey(DispatchKey t); + +// Resolve alias dispatch key to DispatchKeySet if applicable +C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t); + +// Resolve alias dispatch key to DispatchKeySet if applicable, +// and check if k is a part of that set +C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k); + +// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key +// t, DispatchKeySet is empty if t is not alias of DispatchKey::Autograd. +C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); + +// Returns a DispatchKeySet of autograd related keys mapped to backend. +// for a given backend key, use the associated autograd key. +// for non-backend keys, use AutogradOther as a default. +// Note: it's convenient and fast to return a default here rather than (say) +// returning an std::optional, or throwing. But it makes callers +// responsible for either a) enforcing the invariant that only backend keys +// be passed as arguments, or b) interpreting our return value carefully. +inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) { + switch (t) { + case BackendComponent::CPUBit: + return inplace_or_view_ks | autograd_cpu_ks; + case BackendComponent::IPUBit: + return inplace_or_view_ks | autograd_ipu_ks; + case BackendComponent::MTIABit: + return inplace_or_view_ks | autograd_mtia_ks; + case BackendComponent::MAIABit: + return inplace_or_view_ks | autograd_maia_ks; + case BackendComponent::XPUBit: + return inplace_or_view_ks | autograd_xpu_ks; + case BackendComponent::CUDABit: + return inplace_or_view_ks | autograd_cuda_ks; + case BackendComponent::XLABit: + return inplace_or_view_ks | autograd_xla_ks; + case BackendComponent::LazyBit: + return inplace_or_view_ks | autograd_lazy_ks; + case BackendComponent::MetaBit: + return inplace_or_view_ks | autograd_meta_ks; + case BackendComponent::MPSBit: + return inplace_or_view_ks | autograd_mps_ks; + case BackendComponent::HPUBit: + return inplace_or_view_ks | autograd_hpu_ks; + case BackendComponent::PrivateUse1Bit: + return inplace_or_view_ks | autograd_privateuse1_ks; + case BackendComponent::PrivateUse2Bit: + return inplace_or_view_ks | autograd_privateuse2_ks; + case BackendComponent::PrivateUse3Bit: + return inplace_or_view_ks | autograd_privateuse3_ks; + default: + return inplace_or_view_ks | autograd_other_ks; + } +} + +// Returns a DispatchKeySet of autocast related keys mapped to backend. +inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { + constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU); + constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA); + constexpr auto autocast_maia_ks = DispatchKeySet(DispatchKey::AutocastMAIA); + constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU); + constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU); + constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU); + constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA); + constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA); + constexpr auto autocast_privateuse1_ks = + DispatchKeySet(DispatchKey::AutocastPrivateUse1); + constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS); + switch (t) { + case BackendComponent::CPUBit: + return autocast_cpu_ks; + case BackendComponent::MTIABit: + return autocast_mtia_ks; + case BackendComponent::MAIABit: + return autocast_maia_ks; + case BackendComponent::XPUBit: + return autocast_xpu_ks; + case BackendComponent::IPUBit: + return autocast_ipu_ks; + case BackendComponent::HPUBit: + return autocast_hpu_ks; + case BackendComponent::CUDABit: + return autocast_cuda_ks; + case BackendComponent::XLABit: + return autocast_xla_ks; + case BackendComponent::PrivateUse1Bit: + return autocast_privateuse1_ks; + case BackendComponent::MPSBit: + return autocast_mps_ks; + default: + return DispatchKeySet(); + } +} + +// returns the "backend" DispatchKey of highest priority in the set. +// This is basically like highestBackendKey(), except that we have some +// "functionality" bits that correspond to backends (Sparse, Quantized) +inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) { + return (ks & backend_functionality_keys).highestPriorityTypeId(); +} + +// This API exists because we have a use case for checking +// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined) +// in OperatorEntry.cpp but we disallow it in has() API. +C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias); + +// Historically, every tensor only had a single DispatchKey, and it was always +// something like CPU, and there wasn't any of this business where TLS +// could cause the DispatchKey of a tensor to change. But we still have some +// legacy code that is still using DispatchKey for things like instanceof +// checks; if at all possible, refactor the code to stop using DispatchKey in +// those cases. +inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { + // NB: If you add any extra keys that can be stored in TensorImpl on + // top of existing "backend" keys like CPU/CUDA, you need to add it + // here. At the moment, autograd keys and ADInplaceOrView key need this + // treatment; + return (s - autograd_dispatch_keyset_with_ADInplaceOrView - + autocast_dispatch_keyset - + DispatchKeySet( + {DispatchKey::Functionalize, + DispatchKey::PythonTLSSnapshot, + DispatchKey::FuncTorchGradWrapper, + DispatchKey::FuncTorchVmapMode, + DispatchKey::FuncTorchBatched, + DispatchKey::Python})) + .highestPriorityTypeId(); +} + +template +using is_not_DispatchKeySet = std::negation>; + +// Given a function type, constructs a function_traits type that drops the first +// parameter type if the first parameter is of type DispatchKeySet. NB: +// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid +// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through +// the Dispatcher] for details). If at any point in the future we need to expose +// this type to JIT, revisit the usage of this type alias. +template +using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t< + typename guts::infer_function_traits_t::return_type, + typename std::conditional_t< + std::is_same_v< + DispatchKeySet, + typename guts::typelist::head_with_default_t< + void, + typename guts::infer_function_traits_t< + FuncType>::parameter_types>>, + guts::typelist::drop_if_nonempty_t< + typename guts::infer_function_traits_t::parameter_types, + 1>, + typename guts::infer_function_traits_t::parameter_types>>; +} // namespace c10 + +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DynamicCast.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DynamicCast.h new file mode 100644 index 0000000000000000000000000000000000000000..d0f0f0b27c97bf7521a09fae5c6d7c04d9e0b46e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/DynamicCast.h @@ -0,0 +1,134 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace c10 { + +// Dynamic type casting utils: +// - fetch_and_cast +// - cast_and_store +// +// fetch_and_cast fetch a value with dynamic type specified by a ScalarType +// from a void pointer and cast it to a static type. +// +// cast_and_store casts a static typed value into dynamic type specified +// by a ScalarType, and store it into a void pointer. +// +// NOTE: +// +// Dynamic casting allows us to support type promotion without blowing up +// the combination space: For example, without dynamic cast, in order to +// implement `add_` with type promotion, we would need something like +// +// AT_DISPATCH_ALL_TYPES(output.dtype(), +// AT_DISPATCH_ALL_TYPES(input1.dtype(), +// AT_DISPATCH_ALL_TYPES(input2.dtype(), +// [](arg0_t a, arg1_t b) -> out_t { return a + b; } +// ) +// ) +// ) +// +// If we support N dtypes, the above code would generate the a+b kernel for +// all the N * N * N different supported types, the compilation time and +// binary size would become horrible. +// +// Dynamic casting might sounds like a bad idea in terms of performance. +// Especially if you ever do it in a loop, you are going to do a billion tests. +// But in practice it is not as bad as it might look: +// +// - on CPU, this is a branch that always has the same outcome, therefore +// hopefully the branch predictor could do the job pretty well +// - on GPU, these branches will not diverge, so we could still have the same +// warp executing the same line of code +// - Most kernels, like `add`, are bandwidth bound, adding a few clock cycles to +// check an integer does not hurt the performance much because the ALUs would +// wait for load instructions anyway. +// +// For the discussion and benchmark, refer to: +// - https://github.com/pytorch/pytorch/pull/28343 +// - https://github.com/pytorch/pytorch/pull/28344 +// - https://github.com/pytorch/pytorch/pull/28345 +// + +#ifdef C10_HOST_DEVICE +#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false); +#else +#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type"); +#endif + +// Fetch a value with dynamic type src_type from ptr, and cast it to static type +// dest_t. +#define FETCH_AND_CAST_CASE(type, scalartype) \ + case ScalarType::scalartype: \ + return c10::convert(c10::load(ptr)); + +template +C10_HOST_DEVICE inline dest_t fetch_and_cast( + const ScalarType src_type, + const void* ptr) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + switch (src_type) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE) + FETCH_AND_CAST_CASE(uint16_t, UInt16) + FETCH_AND_CAST_CASE(uint32_t, UInt32) + FETCH_AND_CAST_CASE(uint64_t, UInt64) + default: + ERROR_UNSUPPORTED_CAST + } + C10_DIAGNOSTIC_POP() + return dest_t(0); // just to avoid compiler warning +} + +// Cast a value with static type src_t into dynamic dest_type, and store it to +// ptr. +#define CAST_AND_STORE_CASE(type, scalartype) \ + case ScalarType::scalartype: \ + *(type*)ptr = c10::convert(value); \ + return; +template +C10_HOST_DEVICE inline void cast_and_store( + const ScalarType dest_type, + void* ptr, + src_t value) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + switch (dest_type) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE) + CAST_AND_STORE_CASE(uint16_t, UInt16) + CAST_AND_STORE_CASE(uint32_t, UInt32) + CAST_AND_STORE_CASE(uint64_t, UInt64) + default:; + } + C10_DIAGNOSTIC_POP() + ERROR_UNSUPPORTED_CAST +} + +#define DEFINE_UNCASTABLE(T, scalartype_) \ + template <> \ + C10_HOST_DEVICE inline T fetch_and_cast( \ + const ScalarType src_type, const void* ptr) { \ + CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type); \ + return c10::load(ptr); \ + } \ + template <> \ + C10_HOST_DEVICE inline void cast_and_store( \ + const ScalarType dest_type, void* ptr, T value) { \ + CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \ + *(T*)ptr = value; \ + } + +AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE) + +#undef FETCH_AND_CAST_CASE +#undef CAST_AND_STORE_CASE +#undef DEFINE_UNCASTABLE +#undef ERROR_UNSUPPORTED_CAST + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Event.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Event.h new file mode 100644 index 0000000000000000000000000000000000000000..aed1a213bfb4724b5019909adafc237297262f9e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Event.h @@ -0,0 +1,142 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/** + * A backend-generic movable, not copyable, not thread-safe event. + * + * The design of this event follows that of CUDA and HIP events. These events + * are recorded and waited on by streams and can be rerecorded to, + * each rerecording essentially creating a new version of the event. + * For example, if (in CPU time), stream X is asked to record E, + * stream Y waits on E, and stream X is asked to record E again, then Y will + * wait for X to finish the first call to record and not the second, because + * it's waiting on the first version of event E, not the second. + * Querying an event only returns the status of its most recent version. + * + * Backend-generic events are implemented by this class and + * impl::InlineEvent. In addition to these events there are also + * some backend-specific events, like ATen's CUDAEvent. Each of these + * classes has its own use. + * + * impl::InlineEvent<...> or a backend-specific event should be + * preferred when the backend is known at compile time and known to + * be compiled. Backend-specific events may have additional functionality. + * + * This Event should be used if a particular backend may not be available, + * or the backend required is not known at compile time. + * + * These generic events are built on top of DeviceGuardImpls, analogous + * to DeviceGuard and InlineDeviceGuard. The name "DeviceGuardImpls," + * is no longer entirely accurate, as these classes implement the + * backend-specific logic for a generic backend interface. + * + * See DeviceGuardImplInterface.h for a list of all supported flags. + */ + +struct Event final { + // Constructors + Event() = delete; + Event( + const DeviceType _device_type, + const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) + : impl_{_device_type, _flag} {} + + // Copy constructor and copy assignment operator (deleted) + Event(const Event&) = delete; + Event& operator=(const Event&) = delete; + + // Move constructor and move assignment operator + Event(Event&&) noexcept = default; + Event& operator=(Event&&) noexcept = default; + + // Destructor + ~Event() = default; + + // Getters + Device device() const noexcept { + return Device(device_type(), device_index()); + } + DeviceType device_type() const noexcept { + return impl_.device_type(); + } + DeviceIndex device_index() const noexcept { + return impl_.device_index(); + } + EventFlag flag() const noexcept { + return impl_.flag(); + } + bool was_marked_for_recording() const noexcept { + return impl_.was_marked_for_recording(); + } + + /** + * Calls record() if and only if record() has never been called for this + * event. Note: because Event is not thread-safe recordOnce() may call + * record() multiple times if called from multiple threads. + */ + void recordOnce(const Stream& stream) { + impl_.recordOnce(stream); + } + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + void record(const Stream& stream) { + impl_.record(stream); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + void block(const Stream& stream) const { + impl_.block(stream); + } + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + bool query() const { + return impl_.query(); + } + + double elapsedTime(const Event& event) const { + return impl_.elapsedTime(event.impl_); + } + + void* eventId() const { + return impl_.eventId(); + } + + void synchronize() const { + impl_.synchronize(); + } + + private: + impl::InlineEvent impl_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/GeneratorImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/GeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..7d7aac9243ffbbfc4f79471ebceee04ced485219 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/GeneratorImpl.h @@ -0,0 +1,116 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +/** + * Note [Generator] + * ~~~~~~~~~~~~~~~~ + * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm + * to generate a seemingly random sequence of numbers, that may be later be used + * in creating a random distribution. Such an engine almost always maintains a + * state and requires a seed to start off the creation of random numbers. Often + * times, users have found it beneficial to be able to explicitly create, + * retain, and destroy PRNG states and also be able to have control over the + * seed value. + * + * A Generator in ATen gives users the ability to read, write and modify a PRNG + * engine. For instance, it does so by letting users seed a PRNG engine, fork + * the state of the engine, etc. + * + * By default, there is one generator per device, and a device's generator is + * lazily created. A user can use the torch.Generator() api to create their own + * generator. Currently torch.Generator() can only create a CPUGeneratorImpl. + */ + +/** + * Note [Acquire lock when using random generators] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Generator and its derived classes are NOT thread-safe. Please note that most + * of the places where we have inserted locking for generators are historically + * based, and we haven't actually checked that everything is truly thread safe + * (and it probably isn't). Please use the public mutex_ when using any methods + * from these classes, except for the read-only methods. You can learn about the + * usage by looking into the unittests (aten/src/ATen/cpu_generator_test.cpp) + * and other places where we have used lock_guard. + * + * TODO: Look into changing the threading semantics of Generators in ATen (e.g., + * making them non-thread safe and instead making the generator state + * splittable, to accommodate forks into other threads). + */ + +namespace c10 { + +// The default seed is selected to be a large number +// with good distribution of 0s and 1s in bit representation +constexpr uint64_t default_rng_seed_val = 67280421310721; + +struct C10_API GeneratorImpl : public c10::intrusive_ptr_target { + // Constructors + GeneratorImpl(Device device_in, DispatchKeySet key_set); + + // Delete all copy and move assignment in favor of clone() + // method + GeneratorImpl(const GeneratorImpl& other) = delete; + GeneratorImpl(GeneratorImpl&& other) = delete; + GeneratorImpl& operator=(const GeneratorImpl& other) = delete; + GeneratorImpl& operator=(GeneratorImpl&& other) = delete; + + ~GeneratorImpl() override = default; + c10::intrusive_ptr clone() const; + + // Common methods for all generators + virtual void set_current_seed(uint64_t seed) = 0; + virtual void set_offset(uint64_t offset) = 0; + virtual uint64_t get_offset() const = 0; + virtual uint64_t current_seed() const = 0; + virtual uint64_t seed() = 0; + virtual void set_state(const c10::TensorImpl& new_state) = 0; + virtual c10::intrusive_ptr get_state() const = 0; + virtual void graphsafe_set_state( + const c10::intrusive_ptr& new_state); + virtual c10::intrusive_ptr graphsafe_get_state() const; + Device device() const; + + // See Note [Acquire lock when using random generators] + std::mutex mutex_; + + DispatchKeySet key_set() const { + return key_set_; + } + + inline void set_pyobj(PyObject* pyobj) noexcept { + pyobj_ = pyobj; + } + + inline PyObject* pyobj() const noexcept { + return pyobj_; + } + + protected: + Device device_; + DispatchKeySet key_set_; + PyObject* pyobj_ = nullptr; + + virtual GeneratorImpl* clone_impl() const = 0; +}; + +namespace detail { + +C10_API uint64_t getNonDeterministicRandom(bool is_cuda = false); + +} // namespace detail + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/GradMode.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/GradMode.h new file mode 100644 index 0000000000000000000000000000000000000000..391b293f9f005af1035dbf9e43be91bf5b353bed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/GradMode.h @@ -0,0 +1,57 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +struct C10_API GradMode { + static bool is_enabled(); + static void set_enabled(bool enabled); +}; + +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct C10_API AutoGradMode { + AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { + GradMode::set_enabled(enabled); + } + AutoGradMode(const AutoGradMode&) = delete; + AutoGradMode(AutoGradMode&&) = delete; + AutoGradMode& operator=(const AutoGradMode&) = delete; + AutoGradMode& operator=(AutoGradMode&&) = delete; + ~AutoGradMode() { + GradMode::set_enabled(prev_mode); + } + bool prev_mode; +}; + +// A RAII, thread local (!) guard that stops future operations from building +// gradients. +struct C10_API NoGradGuard : public AutoGradMode { + NoGradGuard() : AutoGradMode(/*enabled=*/false) {} +}; + +// A RAII, thread local (!) guard that enables or disables forward grad mode +// upon construction, and sets it back to the original value upon destruction. +struct C10_API AutoFwGradMode { + AutoFwGradMode(bool enabled) + : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { + AutogradState::get_tls_state().set_fw_grad_mode(enabled); + } + AutoFwGradMode(const AutoFwGradMode&) = delete; + AutoFwGradMode(AutoFwGradMode&&) = delete; + AutoFwGradMode& operator=(const AutoFwGradMode&) = delete; + AutoFwGradMode& operator=(AutoFwGradMode&&) = delete; + ~AutoFwGradMode() { + AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); + } + bool prev_mode; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/InferenceMode.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/InferenceMode.h new file mode 100644 index 0000000000000000000000000000000000000000..8da25b5427e61d250268a352f11757a4e1d7ab24 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/InferenceMode.h @@ -0,0 +1,96 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +// A RAII, thread local (!) guard that enables or disables inference mode upon +// construction, and sets it back to the original value upon destruction. +struct C10_API InferenceMode { + // Note [Expected TLS state in InferenceMode]: + // InferenceMode: ADInplaceOrView not in + // raw_local_dispatch_key_set.included(), + // Autograd in raw_local_dispatch_key_set.excluded() + // GradMode is disabled. + // NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(), + // Autograd not in raw_local_dispatch_key_set.excluded() + // GradMode is enabled by default unless toggled manually + // through other APIs, e.g. NoGradGuard. + // + // Invariant: + // - ADInplaceOrView is never in the excluded set + // - Autograd is never in the included set + // - Setting InferenceMode will set GradMode accordingly, but not vice versa. + // + // 1. Why do we put ADInplaceOrView in included set outside InferenceMode? + // + // Inplace update to inference tensor outside InferenceMode is not + // allowed. See Note [Inplace update inference tensor] for more details. + // Without going through ADInplaceOrView kernel, we cannot throw error + // for `inference_tensor.add_(1)` case. + // + // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode? + // + // For example: + // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); + // torch::Tensor k = a + 2; + // { + // c10::InferenceMode guard(true); + // k.add_(2); + // } + // `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's + // prepared for future autograd. + // + // 3. Why does setting InferenceMode also set GradMode? + // + // This is required since InferenceMode is a faster and more restrictive + // version of NoGradGuard. All runtime checks using GradMode::is_enabled() + // are applicable to InferenceMode as well, e.g. + // `tensorTypeInCurrentExecutionContext` in interpreter.cpp. + InferenceMode(bool enabled = true) + : prev_mode(AutogradState::get_tls_state()), + prev_keyset(c10::impl::tls_local_dispatch_key_set()) { + // Enabling inference mode means disabling grad modes + // And disabling inference mode means enabling grad modes + AutogradState::set_tls_state(AutogradState( + /* grad_mode */ !enabled, + /* inference_mode */ enabled, + /* fw_grad_mode */ !enabled, + /* multithreading_enabled*/ !enabled)); + DispatchKeySet included = enabled + ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView) + : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView); + DispatchKeySet excluded = enabled + ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) + : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); + c10::impl::PODLocalDispatchKeySet cur_keyset{}; + cur_keyset.set_included(included); + cur_keyset.set_excluded(excluded); + c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); + } + + InferenceMode(const InferenceMode&) = delete; + InferenceMode(InferenceMode&&) = delete; + InferenceMode& operator=(const InferenceMode&) = delete; + InferenceMode& operator=(InferenceMode&&) = delete; + + ~InferenceMode() { + AutogradState::set_tls_state(prev_mode); + c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); + } + static bool is_enabled(); + + private: + AutogradState prev_mode; + c10::impl::LocalDispatchKeySet prev_keyset; +}; +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Layout.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Layout.h new file mode 100644 index 0000000000000000000000000000000000000000..194e1863cb18cf2759f2c4e3e1ace298efd76150 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Layout.h @@ -0,0 +1,67 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include + +namespace c10 { + +inline Layout layout_from_backend(Backend backend) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + switch (backend) { + case Backend::SparseCPU: + case Backend::SparseCUDA: + case Backend::SparseMPS: + case Backend::SparseHIP: + case Backend::SparseVE: + case Backend::SparseXPU: + case Backend::SparsePrivateUse1: + return Layout::Sparse; + case Backend::MkldnnCPU: + return Layout::Mkldnn; + case Backend::SparseCsrCPU: + case Backend::SparseCsrCUDA: + case Backend::SparseCsrMPS: + case Backend::SparseCsrHIP: + case Backend::SparseCsrVE: + case Backend::SparseCsrXPU: + TORCH_CHECK( + false, + "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU|MPS) to a unique layout."); + default: + return Layout::Strided; + } + C10_DIAGNOSTIC_POP() +} + +inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) { + switch (layout) { + case at::kStrided: + return stream << "Strided"; + case at::kSparse: + return stream << "Sparse"; + case at::kSparseCsr: + return stream << "SparseCsr"; + case at::kSparseCsc: + return stream << "SparseCsc"; + case at::kSparseBsr: + return stream << "SparseBsr"; + case at::kSparseBsc: + return stream << "SparseBsc"; + case at::kMkldnn: + return stream << "Mkldnn"; + case at::kJagged: + return stream << "Jagged"; + case Layout::NumOptions: + default: + TORCH_CHECK(false, "Unknown layout"); + } +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/MemoryFormat.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/MemoryFormat.h new file mode 100644 index 0000000000000000000000000000000000000000..63cdb757952b073d957fc91c33357136c1287679 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/MemoryFormat.h @@ -0,0 +1,268 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include + +#include +#include + +namespace c10 { + +// If you are seeing this, it means that this call site was not checked if +// the memory format could be preserved, and it was switched to old default +// behaviour of contiguous +#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format() + +inline std::ostream& operator<<( + std::ostream& stream, + at::MemoryFormat memory_format) { + switch (memory_format) { + case MemoryFormat::Preserve: + return stream << "Preserve"; + case MemoryFormat::Contiguous: + return stream << "Contiguous"; + case MemoryFormat::ChannelsLast: + return stream << "ChannelsLast"; + case MemoryFormat::ChannelsLast3d: + return stream << "ChannelsLast3d"; + case MemoryFormat::NumOptions: + default: + TORCH_CHECK(false, "Unknown memory format ", memory_format); + } +} + +// Note: Hardcoded the channel last stride indices here to get better +// performance +template +inline std::vector get_channels_last_strides_2d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 4: + strides[1] = 1; + strides[3] = sizes[1]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 3: + strides[0] = 1; + strides[2] = sizes[0]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + TORCH_INTERNAL_ASSERT( + false, "ChannelsLast2d doesn't support size ", sizes.size()); + } +} + +inline std::vector get_channels_last_strides_2d(IntArrayRef sizes) { + return get_channels_last_strides_2d(sizes); +} + +template +std::vector get_channels_last_strides_3d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 5: + strides[1] = 1; + strides[4] = sizes[1]; + strides[3] = strides[4] * sizes[4]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 4: + strides[0] = 1; + strides[3] = sizes[0]; + strides[2] = strides[3] * sizes[3]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + TORCH_INTERNAL_ASSERT( + false, "ChannelsLast3d doesn't support size ", sizes.size()); + } +} + +inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { + return get_channels_last_strides_3d(sizes); +} + +// NOTE: +// Below are Helper functions for is_channels_last_strides_xd. +// 1. Please do not combine these helper functions, each helper function handles +// exactly one case of sizes + memory_format, by doing this, the strides indices +// will be a constant array and we can access it using constant index number, +// the compiler will fully unroll the loop on strides indices to gain a better +// performance. +// 2. No error check in helper function, caller ensures the correctness of the +// input +// 3. All helper functions have similar comments, only 1st helper function is +// commented here. +template +inline bool is_channels_last_strides_2d_s4( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + // special case for trivial C dimension. default to NCHW + if (strides[1] == 0) { + return false; + } + // loop strides indices + for (auto& d : {1, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + // Fallback to NCHW as default layout for ambiguous cases + // This is the flaw of implicit memory_format from strides. + // N111 tensor with identical strides for size 1 dimension; + // Two cases could lead us here: + // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + // b. N11W contiguous Tensor sliced on the W-dimension. + // ([N,1,1,1]@[W,W,W,W]) + if (d == 0 && min == strides[1]) { + return false; + } + // This is necessary to: + // 1. distinguish the memory_format of N1H1; + // [H, 1, 1, 1] channels_last stride + // [H, H, 1, 1] contiguous stride + // 2. permutation of 1C1W: + // [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + // [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +template +inline bool is_channels_last_strides_3d_s5( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + if (strides[1] == 0) { + return false; + } + for (auto& d : {1, 4, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + if (d == 0 && min == strides[1]) { + return false; + } + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +// Note [Ambiguous is_channels_last_strides_xd] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// The flaw of carrying memory_format implicitly through strides is very hard +// to WAR properly. issue #24090 +// Without the history of permutation, we can't infer the memory_format of a +// tensor from the snapshot of its size & stride +// e.g. +// +// 1. We can NOT specify the memory_format of N111 tensor through strides in a +// meaningful way; +// +// 2. Two path that ended up with identical size/stride +// N11W contiguous tensor sliced at w-dimension becomes [N,1,1,1]@[W,W,W,W] +// NC11 channels_last tensor sliced at c-dimension becomes [N,1,1,1]@[C,C,C,C] +// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer +// the memory_format of the original tensor. +// +// Due to the limitations, our temporary WAR `is_channels_last_strides` does the +// best effort to infer whether the original memory_format of a tensor is +// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered +// by their importance): +// 1. Ensure that normal shape manipulation does not accidentally change the +// MemoryFormat of an existing tensor. +// 2. Allows user to mark MemoryFormat::ChannelsLast to tensors; +// +// The function does so via checking strides of the tensor, including strides of +// size-1 dimensions. Although conventionally PyTorch implies no restriction on +// trivial stride (stride for size-1 dimension). +// +// Note that this approach is a compromise. We did not solve the problem +// completely. Many cases we will not be able to infer the correct memory +// format. +// The implementation of `is_channels_last_strides` is to serve the objectives: +// MemoryFormat::ChannelsLast has to be explicitly opted-in (no accidental +// conversion); Best effort to maintain the ChannelsLast flag. +// +// Due to the fact that this is not a bulletproof solution, through testing +// (aten/src/ATen/test/memory_format_test.cpp) +// a. we ensure that the common tasks are supported; +// a. we identify corner cases where the implementation compromises on. +// +// By the time accumulated permutation is enabled to replace implicit +// memory_format through strides, we should be updating our tests and fix the +// issues in our tests. +// +// We use Channels Last 2d as an example above. +// This is a general problem for all the is_channels_last_strides_xd +// implementation. Please check the helper functions +// (is_channels_last_strides_*d_s*) for more details. + +template +inline bool is_channels_last_strides_2d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 4: + return is_channels_last_strides_2d_s4(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +inline bool is_channels_last_strides_3d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 5: + return is_channels_last_strides_3d_s5(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +inline bool is_channels_last_strides_2d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_2d(sizes, strides); +} + +inline bool is_channels_last_strides_3d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_3d(sizes, strides); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/OptionalRef.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/OptionalRef.h new file mode 100644 index 0000000000000000000000000000000000000000..f1199e1945a65866cfd17c5301e20454721dc117 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/OptionalRef.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +namespace c10 { + +template +class OptionalRef { + public: + OptionalRef() : data_(nullptr) {} + OptionalRef(const T* data) : data_(data) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data_); + } + OptionalRef(const T& data) : data_(&data) {} + + bool has_value() const { + return data_ != nullptr; + } + + const T& get() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data_); + return *data_; + } + + operator bool() const { + return has_value(); + } + + private: + const T* data_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/PyHandleCache.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/PyHandleCache.h new file mode 100644 index 0000000000000000000000000000000000000000..1c39510078bc70aa95e205176fd8bebeeb332065 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/PyHandleCache.h @@ -0,0 +1,81 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include + +namespace c10 { + +// A PyHandleCache represents a cached pointer from a C++ object to +// a Python object that represents that object analogously in Python. +// Upon a cache hit, the relevant object can be retrieved after a test +// and then a memory load. Two conditions must hold to be able to use this +// class: +// +// - This must truly be a cache; e.g., the caller must be able to produce +// the object some other way if the cache hit misses. +// +// - This must truly be a handle; e.g., the Python object referenced by +// this class must have static lifetime. This means we don't have to +// maintain strong ownership or deallocate the object when the C++ object +// dies. Static lifetime is a good idea in conjunction with the cache, +// since if you are producing a fresh object on miss you won't be +// maintaining object identity. If you need bidirectional ownership, +// you will want to factor out the pattern in TensorImpl with +// resurrection. +// +// This cache is expected to not improve perf under torchdeploy, as one +// interpreter will fill up the cache, and all the interpreters will be +// unable to use the slot. A potential improvement is to have multiple +// slots (one per interpreter), which will work in deployment scenarios +// where there a stable, fixed number of interpreters. You can also store +// the relevant state in the Python library, rather than in the non-Python +// library (although in many cases, this is not convenient, as there may +// not be a way to conveniently index based on the object.) +class PyHandleCache { + public: + PyHandleCache() : pyinterpreter_(nullptr) {} + + // Attempt to fetch the pointer from the cache, if the PyInterpreter + // matches. If it doesn't exist, or the cache entry is not valid, + // use slow_accessor to get the real pointer value and return that + // (possibly writing it to the cache, if the cache entry is + // available.) + template + PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor) + const { + // Note [Memory ordering on Python interpreter tag] + impl::PyInterpreter* interpreter = + pyinterpreter_.load(std::memory_order_acquire); + if (C10_LIKELY(interpreter == self_interpreter)) { + return data_; + } else if (interpreter == nullptr) { + auto* r = slow_accessor(); + impl::PyInterpreter* expected = nullptr; + // attempt to claim this cache entry with the specified interpreter tag + if (pyinterpreter_.compare_exchange_strong( + expected, self_interpreter, std::memory_order_acq_rel)) { + data_ = r; + } + // This shouldn't be possible, as you should be GIL protected + TORCH_INTERNAL_ASSERT(expected != self_interpreter); + return r; + } else { + return slow_accessor(); + } + } + + private: + mutable std::atomic pyinterpreter_; + mutable PyObject* data_{nullptr}; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/QEngine.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/QEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..b0bb6a245643a3e093c02ae80756403b931245ba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/QEngine.h @@ -0,0 +1,51 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10 { + +/** + * QEngine is an enum that is used to select the engine to run quantized ops. + * Keep this enum in sync with get_qengine_id() in + * torch/backends/quantized/__init__.py + */ +enum class QEngine : uint8_t { + NoQEngine = 0, + FBGEMM = 1, + QNNPACK = 2, + ONEDNN = 3, + X86 = 4, +}; + +constexpr auto kNoQEngine = QEngine::NoQEngine; +constexpr auto kFBGEMM = QEngine::FBGEMM; +constexpr auto kQNNPACK = QEngine::QNNPACK; +constexpr auto kONEDNN = QEngine::ONEDNN; +constexpr auto kX86 = QEngine::X86; + +inline std::string toString(QEngine qengine) { + switch (qengine) { + case kNoQEngine: + return "NoQEngine"; + case kFBGEMM: + return "FBGEMM"; + case kQNNPACK: + return "QNNPACK"; + case kONEDNN: + return "ONEDNN"; + case kX86: + return "X86"; + default: + TORCH_CHECK( + false, "Unrecognized Quantized Engine: ", static_cast(qengine)); + } +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/QScheme.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/QScheme.h new file mode 100644 index 0000000000000000000000000000000000000000..f557affb1de8ff54fc961159d3cc67e2f11ef3b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/QScheme.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + +namespace c10 { + +/** + * QScheme is an enum that specifies the type of quantization. This has a one + * to one correspondence with Quantizer + * Please refer to ATen/quantized/Quantizer.h to see the Quantizers classes. + * Keep this file in sync with torch/nn/_qscheme.py + */ +enum class QScheme : uint8_t { + PER_TENSOR_AFFINE = 0, + PER_CHANNEL_AFFINE = 1, + PER_TENSOR_SYMMETRIC = 2, + PER_CHANNEL_SYMMETRIC = 3, + PER_CHANNEL_AFFINE_FLOAT_QPARAMS = 4, + COMPILE_TIME_NUM_QSCHEMES = 5, +}; + +constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE; +constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE; +constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC; +constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC; +constexpr auto kPerChannelAffineFloatQParams = + QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS; +constexpr int COMPILE_TIME_NUM_QSCHEMES = + static_cast(QScheme::COMPILE_TIME_NUM_QSCHEMES); + +inline std::string toString(QScheme qscheme) { + switch (qscheme) { + case kPerTensorAffine: + return "per_tensor_affine"; + case kPerChannelAffine: + return "per_channel_affine"; + case kPerTensorSymmetric: + return "per_tensor_symmetric"; + case kPerChannelSymmetric: + return "per_channel_symmetric"; + case kPerChannelAffineFloatQParams: + return "per_channel_affine_float_qparams"; + default: + TORCH_CHECK(false, "Unrecognized qscheme: ", static_cast(qscheme)); + } +} + +} // namespace c10 + +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/RefcountedDeleter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/RefcountedDeleter.h new file mode 100644 index 0000000000000000000000000000000000000000..8b1e9ca7071a032e6a383dc539b8010af535471b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/RefcountedDeleter.h @@ -0,0 +1,57 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include +#include + +namespace c10 { + +// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr +// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use +// this custom context and the `refcounted_deleter` function below to make the +// DataPtr act like a non-unique DataPtr. This context object holds onto an +// inner context and deleter function which handle the actual deletion of the +// data when the refcount reaches 0. +// +// This shared DataPtr feature is only used when storages are shared between +// multiple Python interpreters in MultiPy. // codespell:ignore multipy +// Before storages had PyObject preservation, interpreters could just share the +// same StorageImpl instance. But now a StorageImpl can only be associated with +// one interpreter in order to properly manage a zombie PyObject. So we share +// storages across Python interpreters by creating a different StorageImpl +// instance for each one, but they all point to the same data. +struct C10_API RefcountedDeleterContext { + RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter) + : other_ctx(other_ctx, other_deleter), refcount(1) {} + + std::unique_ptr other_ctx; + std::atomic_int refcount; +}; + +// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement +// a shared DataPtr. +// +// Warning: This should only be called on a pointer to +// a RefcountedDeleterContext that was allocated on the heap with `new`, +// because when the refcount reaches 0, the context is deleted with `delete` +C10_API void refcounted_deleter(void* ctx_); + +// If the storage's DataPtr does not use `refcounted_deleter`, replace it with +// a DataPtr that does, so it can be shared between multiple StorageImpls +C10_API void maybeApplyRefcountedDeleter(const c10::Storage& storage); + +// Create a new StorageImpl that points to the same data. If the original +// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced +// with one that does +C10_API c10::Storage newStorageImplFromRefcountedDataPtr( + const c10::Storage& storage); + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SafePyObject.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SafePyObject.h new file mode 100644 index 0000000000000000000000000000000000000000..bf8eee0e004b5e49c39d9718736df1099769ef24 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SafePyObject.h @@ -0,0 +1,125 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace c10 { + +// This is an safe owning holder for a PyObject, akin to pybind11's +// py::object, with two major differences: +// +// - It is in c10/core; i.e., you can use this type in contexts where +// you do not have a libpython dependency +// +// - It is multi-interpreter safe (ala torchdeploy); when you fetch +// the underlying PyObject* you are required to specify what the current +// interpreter context is and we will check that you match it. +// +// It is INVALID to store a reference to a Tensor object in this way; +// you should just use TensorImpl directly in that case! +struct C10_API SafePyObject { + // Steals a reference to data + SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) + : data_(data), pyinterpreter_(pyinterpreter) {} + SafePyObject(SafePyObject&& other) noexcept + : data_(std::exchange(other.data_, nullptr)), + pyinterpreter_(other.pyinterpreter_) {} + // For now it's not used, so we just disallow it. + SafePyObject& operator=(SafePyObject&&) = delete; + + SafePyObject(SafePyObject const& other) + : data_(other.data_), pyinterpreter_(other.pyinterpreter_) { + if (data_ != nullptr) { + (*pyinterpreter_)->incref(data_); + } + } + + SafePyObject& operator=(SafePyObject const& other) { + if (this == &other) { + return *this; // Handle self-assignment + } + if (other.data_ != nullptr) { + (*other.pyinterpreter_)->incref(other.data_); + } + if (data_ != nullptr) { + (*pyinterpreter_)->decref(data_); + } + data_ = other.data_; + pyinterpreter_ = other.pyinterpreter_; + return *this; + } + + ~SafePyObject() { + if (data_ != nullptr) { + (*pyinterpreter_)->decref(data_); + } + } + + c10::impl::PyInterpreter& pyinterpreter() const { + return *pyinterpreter_; + } + PyObject* ptr(const c10::impl::PyInterpreter* /*interpreter*/) const; + + // stop tracking the current object, and return it + PyObject* release() { + auto rv = data_; + data_ = nullptr; + return rv; + } + + private: + PyObject* data_; + c10::impl::PyInterpreter* pyinterpreter_; +}; + +// A newtype wrapper around SafePyObject for type safety when a python object +// represents a specific type. Note that `T` is only used as a tag and isn't +// actually used for any true purpose. +template +struct SafePyObjectT : private SafePyObject { + SafePyObjectT(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) + : SafePyObject(data, pyinterpreter) {} + ~SafePyObjectT() = default; + SafePyObjectT(SafePyObjectT&& other) noexcept : SafePyObject(other) {} + SafePyObjectT(SafePyObjectT const&) = delete; + SafePyObjectT& operator=(SafePyObjectT const&) = delete; + SafePyObjectT& operator=(SafePyObjectT&&) = delete; + + using SafePyObject::ptr; + using SafePyObject::pyinterpreter; + using SafePyObject::release; +}; + +// Like SafePyObject, but non-owning. Good for references to global PyObjects +// that will be leaked on interpreter exit. You get a copy constructor/assign +// this way. +struct C10_API SafePyHandle { + SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {} + SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) + : data_(data), pyinterpreter_(pyinterpreter) {} + + c10::impl::PyInterpreter& pyinterpreter() const { + return *pyinterpreter_; + } + PyObject* ptr(const c10::impl::PyInterpreter* /*interpreter*/) const; + void reset() { + data_ = nullptr; + pyinterpreter_ = nullptr; + } + operator bool() { + return data_; + } + + private: + PyObject* data_; + c10::impl::PyInterpreter* pyinterpreter_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Scalar.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Scalar.h new file mode 100644 index 0000000000000000000000000000000000000000..863a993ed08a614ca4526fee426ebd46f5633be0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Scalar.h @@ -0,0 +1,471 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/** + * Scalar represents a 0-dimensional tensor which contains a single element. + * Unlike a tensor, numeric literals (in C++) are implicitly convertible to + * Scalar (which is why, for example, we provide both add(Tensor) and + * add(Scalar) overloads for many operations). It may also be used in + * circumstances where you statically know a tensor is 0-dim and single size, + * but don't know its type. + */ +class C10_API Scalar { + public: + Scalar() : Scalar(int64_t(0)) {} + + void destroy() { + if (Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag) { + raw::intrusive_ptr::decref(v.p); + v.p = nullptr; + } + } + + ~Scalar() { + destroy(); + } + +#define DEFINE_IMPLICIT_CTOR(type, name) \ + Scalar(type vv) : Scalar(vv, true) {} + + AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR) + AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) + AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR) + + // Helper constructors to allow Scalar creation from long and long long types + // As std::is_same_v is false(except Android), one needs to + // provide a constructor from either long or long long in addition to one from + // int64_t +#if defined(__APPLE__) || defined(__MACOSX) + static_assert( + std::is_same_v, + "int64_t is the same as long long on MacOS"); + Scalar(long vv) : Scalar(vv, true) {} +#endif +#if defined(_MSC_VER) + static_assert( + std::is_same_v, + "int64_t is the same as long long on Windows"); + Scalar(long vv) : Scalar(vv, true) {} +#endif +#if defined(__linux__) && !defined(__ANDROID__) + static_assert( + sizeof(void*) != 8 || std::is_same_v, + "int64_t is the same as long on 64 bit Linux"); +#if LONG_MAX != INT_MAX + Scalar(long long vv) : Scalar(vv, true) {} +#endif /* not 32-bit system */ +#endif + + Scalar(uint16_t vv) : Scalar(vv, true) {} + Scalar(uint32_t vv) : Scalar(vv, true) {} + Scalar(uint64_t vv) { + if (vv > static_cast(INT64_MAX)) { + tag = Tag::HAS_u; + v.u = vv; + } else { + tag = Tag::HAS_i; + // NB: no need to use convert, we've already tested convertibility + v.i = static_cast(vv); + } + } + +#undef DEFINE_IMPLICIT_CTOR + + // Value* is both implicitly convertible to SymbolicVariable and bool which + // causes ambiguity error. Specialized constructor for bool resolves this + // problem. + template < + typename T, + typename std::enable_if_t, bool>* = nullptr> + Scalar(T vv) : tag(Tag::HAS_b) { + v.i = convert(vv); + } + + template < + typename T, + typename std::enable_if_t, bool>* = + nullptr> + Scalar(T vv) : tag(Tag::HAS_sb) { + v.i = convert(vv); + } + +#define DEFINE_ACCESSOR(type, name) \ + type to##name() const { \ + if (Tag::HAS_d == tag) { \ + return checked_convert(v.d, #type); \ + } else if (Tag::HAS_z == tag) { \ + return checked_convert>(v.z, #type); \ + } else if (Tag::HAS_sd == tag) { \ + return checked_convert( \ + toSymFloat().guard_float(__FILE__, __LINE__), #type); \ + } \ + if (Tag::HAS_b == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_i == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_u == tag) { \ + return checked_convert(v.u, #type); \ + } else if (Tag::HAS_si == tag) { \ + return checked_convert( \ + toSymInt().guard_int(__FILE__, __LINE__), #type); \ + } else if (Tag::HAS_sb == tag) { \ + return checked_convert( \ + toSymBool().guard_bool(__FILE__, __LINE__), #type); \ + } \ + TORCH_CHECK(false) \ + } + + // TODO: Support ComplexHalf accessor + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR) + DEFINE_ACCESSOR(uint16_t, UInt16) + DEFINE_ACCESSOR(uint32_t, UInt32) + DEFINE_ACCESSOR(uint64_t, UInt64) + +#undef DEFINE_ACCESSOR + + SymInt toSymInt() const { + if (Tag::HAS_si == tag) { + return c10::SymInt(intrusive_ptr::reclaim_copy( + static_cast(v.p))); + } else { + return toLong(); + } + } + + SymFloat toSymFloat() const { + if (Tag::HAS_sd == tag) { + return c10::SymFloat(intrusive_ptr::reclaim_copy( + static_cast(v.p))); + } else { + return toDouble(); + } + } + + SymBool toSymBool() const { + if (Tag::HAS_sb == tag) { + return c10::SymBool(intrusive_ptr::reclaim_copy( + static_cast(v.p))); + } else { + return toBool(); + } + } + + // also support scalar.to(); + // Deleted for unsupported types, but specialized below for supported types + template + T to() const = delete; + + // audit uses of data_ptr + const void* data_ptr() const { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return static_cast(&v); + } + + bool isFloatingPoint() const { + return Tag::HAS_d == tag || Tag::HAS_sd == tag; + } + + [[deprecated( + "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")]] bool + isIntegral() const { + return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag; + } + + bool isIntegral(bool includeBool) const { + return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag || + (includeBool && isBoolean()); + } + + // See Note [Meaning of HAS_u] + bool isUnsigned() const { + return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0); + } + + bool isComplex() const { + return Tag::HAS_z == tag; + } + bool isBoolean() const { + return Tag::HAS_b == tag || Tag::HAS_sb == tag; + } + + // you probably don't actually want these; they're mostly for testing + bool isSymInt() const { + return Tag::HAS_si == tag; + } + bool isSymFloat() const { + return Tag::HAS_sd == tag; + } + bool isSymBool() const { + return Tag::HAS_sb == tag; + } + + bool isSymbolic() const { + return Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag; + } + + C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept { + if (&other == this) { + return *this; + } + + destroy(); + moveFrom(std::move(other)); + return *this; + } + + C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) { + if (&other == this) { + return *this; + } + + *this = Scalar(other); + return *this; + } + + Scalar operator-() const; + Scalar conj() const; + Scalar log() const; + + template < + typename T, + typename std::enable_if_t::value, int> = 0> + bool equal(T num) const { + if (isComplex()) { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + auto val = v.z; + return (val.real() == num) && (val.imag() == T()); + } else if (isFloatingPoint()) { + return toDouble() == num; + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num; + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num; + } + } else if (tag == Tag::HAS_si) { + TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality"); + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return false; + } else { + TORCH_INTERNAL_ASSERT(false); + } + } + + template < + typename T, + typename std::enable_if_t::value, int> = 0> + bool equal(T num) const { + if (isComplex()) { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return v.z == num; + } else if (isFloatingPoint()) { + return (toDouble() == num.real()) && (num.imag() == T()); + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num.real() && num.imag() == T(); + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num.real() && num.imag() == T(); + } + } else if (tag == Tag::HAS_si) { + TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality"); + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return false; + } else { + TORCH_INTERNAL_ASSERT(false); + } + } + + bool equal(bool num) const { + if (isBoolean()) { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return static_cast(v.i) == num; + } else { + return false; + } + } + + ScalarType type() const { + if (isComplex()) { + return ScalarType::ComplexDouble; + } else if (isFloatingPoint()) { + return ScalarType::Double; + } else if (isIntegral(/*includeBool=*/false)) { + // Represent all integers as long, UNLESS it is unsigned and therefore + // unrepresentable as long + if (Tag::HAS_u == tag) { + return ScalarType::UInt64; + } + return ScalarType::Long; + } else if (isBoolean()) { + return ScalarType::Bool; + } else { + TORCH_CHECK(false, "Unknown scalar type."); + } + } + + Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) { + moveFrom(std::move(rhs)); + } + + Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) { + if (isSymbolic()) { + c10::raw::intrusive_ptr::incref(v.p); + } + } + + Scalar(c10::SymInt si) { + if (auto m = si.maybe_as_int()) { + tag = Tag::HAS_i; + v.i = *m; + } else { + tag = Tag::HAS_si; + v.p = std::move(si).release(); + } + } + + Scalar(c10::SymFloat sd) { + if (sd.is_symbolic()) { + tag = Tag::HAS_sd; + v.p = std::move(sd).release(); + } else { + tag = Tag::HAS_d; + v.d = sd.as_float_unchecked(); + } + } + + Scalar(c10::SymBool sb) { + if (auto m = sb.maybe_as_bool()) { + tag = Tag::HAS_b; + v.i = *m; + } else { + tag = Tag::HAS_sb; + v.p = std::move(sb).release(); + } + } + + // We can't set v in the initializer list using the + // syntax v{ .member = ... } because it doesn't work on MSVC + private: + enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b, HAS_sd, HAS_si, HAS_sb }; + + // Note [Meaning of HAS_u] + // ~~~~~~~~~~~~~~~~~~~~~~~ + // HAS_u is a bit special. On its face, it just means that we + // are holding an unsigned integer. However, we generally don't + // distinguish between different bit sizes in Scalar (e.g., we represent + // float as double), instead, it represents a mathematical notion + // of some quantity (integral versus floating point). So actually, + // HAS_u is used solely to represent unsigned integers that could + // not be represented as a signed integer. That means only uint64_t + // potentially can get this tag; smaller types like uint8_t fits into a + // regular int and so for BC reasons we keep as an int. + + // NB: assumes that self has already been cleared + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept { + v = rhs.v; + tag = rhs.tag; + if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd || + rhs.tag == Tag::HAS_sb) { + // Move out of scalar + rhs.tag = Tag::HAS_i; + rhs.v.i = 0; + } + } + + Tag tag; + + union v_t { + double d{}; + int64_t i; + // See Note [Meaning of HAS_u] + uint64_t u; + c10::complex z; + c10::intrusive_ptr_target* p; + // NOLINTNEXTLINE(modernize-use-equals-default) + v_t() {} // default constructor + } v; + + template < + typename T, + typename std::enable_if_t< + std::is_integral_v && !std::is_same_v, + bool>* = nullptr> + Scalar(T vv, bool /*unused*/) : tag(Tag::HAS_i) { + v.i = convert(vv); + } + + template < + typename T, + typename std::enable_if_t< + !std::is_integral_v && !c10::is_complex::value, + bool>* = nullptr> + Scalar(T vv, bool /*unused*/) : tag(Tag::HAS_d) { + v.d = convert(vv); + } + + template < + typename T, + typename std::enable_if_t::value, bool>* = nullptr> + Scalar(T vv, bool /*unused*/) : tag(Tag::HAS_z) { + v.z = convert(vv); + } +}; + +using OptionalScalarRef = c10::OptionalRef; + +// define the scalar.to() specializations +#define DEFINE_TO(T, name) \ + template <> \ + inline T Scalar::to() const { \ + return to##name(); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO) +DEFINE_TO(uint16_t, UInt16) +DEFINE_TO(uint32_t, UInt32) +DEFINE_TO(uint64_t, UInt64) +#undef DEFINE_TO + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ScalarType.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ScalarType.h new file mode 100644 index 0000000000000000000000000000000000000000..b678a22630d3d9e625b62149a580b3a0b3bbed9a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ScalarType.h @@ -0,0 +1,285 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default") + +namespace c10 { + +// See [dtype Macros note] in torch/headeronly/core/ScalarType.h +// regarding macros. + +#define DEFINE_CONSTANT(_, name) \ + constexpr ScalarType k##name = ScalarType::name; + +// NOLINTNEXTLINE(clang-diagnostic-unused-const-variable) +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) +#undef DEFINE_CONSTANT + +inline size_t elementSize(ScalarType t) { +#define CASE_ELEMENTSIZE_CASE(ctype, name) \ + case ScalarType::name: \ + return sizeof(ctype); + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) + default: + TORCH_CHECK(false, "Unknown ScalarType"); + } +#undef CASE_ELEMENTSIZE_CASE +} + +inline bool isIntegralType(ScalarType t, bool includeBool) { + bool isIntegral = + (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || + t == ScalarType::Long || t == ScalarType::Short || + t == ScalarType::UInt16 || t == ScalarType::UInt32 || + t == ScalarType::UInt64); + + return isIntegral || (includeBool && t == ScalarType::Bool); +} + +[[deprecated( + "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")]] inline bool +isIntegralType(ScalarType t) { + return isIntegralType(t, /*includeBool=*/false); +} + +inline bool isFloat8Type(ScalarType t) { + return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz || + t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz || + t == ScalarType::Float8_e8m0fnu; +} + +inline bool isReducedFloatingType(ScalarType t) { + return t == ScalarType::Half || t == ScalarType::BFloat16 || + isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2; +} + +inline bool isFloatingType(ScalarType t) { + return t == ScalarType::Double || t == ScalarType::Float || + isReducedFloatingType(t); +} + +inline bool isComplexType(ScalarType t) { + return ( + t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || + t == ScalarType::ComplexDouble); +} + +inline bool isBitsType(ScalarType t) { + return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || + t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || + t == ScalarType::Bits16; +} + +inline bool isBarebonesUnsignedType(ScalarType t) { + return t == ScalarType::UInt1 || t == ScalarType::UInt2 || + t == ScalarType::UInt3 || t == ScalarType::UInt4 || + t == ScalarType::UInt5 || t == ScalarType::UInt6 || + t == ScalarType::UInt7 || t == ScalarType::UInt16 || + t == ScalarType::UInt32 || t == ScalarType::UInt64; +} + +inline ScalarType toQIntType(ScalarType t) { + switch (t) { + case ScalarType::Byte: + return ScalarType::QUInt8; + case ScalarType::Char: + return ScalarType::QInt8; + case ScalarType::Int: + return ScalarType::QInt32; + default: + return t; + } +} + +inline bool isSignedType(ScalarType t) { +#define CASE_ISSIGNED(name) \ + case ScalarType::name: \ + return std::numeric_limits< \ + ::c10::impl::ScalarTypeToCPPTypeT>::is_signed; + + // TODO(#146647): If we expect to have numeric_limits for everything, + // let's just have a big macro for the whole thing. + // If we're hardcoding it, let's just use the macro and a "true"/"false" + // below? + switch (t) { + case ScalarType::QInt8: + case ScalarType::QUInt8: + case ScalarType::QInt32: + case ScalarType::QUInt4x2: + case ScalarType::QUInt2x4: + TORCH_CHECK(false, "isSignedType not supported for quantized types"); + case ScalarType::Bits1x8: + case ScalarType::Bits2x4: + case ScalarType::Bits4x2: + case ScalarType::Bits8: + case ScalarType::Bits16: + TORCH_CHECK(false, "Bits types are undefined"); + CASE_ISSIGNED(UInt16); + CASE_ISSIGNED(UInt32); + CASE_ISSIGNED(UInt64); + CASE_ISSIGNED(BFloat16); + CASE_ISSIGNED(Float8_e5m2); + CASE_ISSIGNED(Float8_e5m2fnuz); + CASE_ISSIGNED(Float8_e4m3fn); + CASE_ISSIGNED(Float8_e4m3fnuz); + CASE_ISSIGNED(Float8_e8m0fnu); + CASE_ISSIGNED(Byte); + CASE_ISSIGNED(Char); + CASE_ISSIGNED(Short); + CASE_ISSIGNED(Int); + CASE_ISSIGNED(Long); + CASE_ISSIGNED(Half); + CASE_ISSIGNED(Float); + CASE_ISSIGNED(Double); + CASE_ISSIGNED(ComplexHalf); + CASE_ISSIGNED(ComplexFloat); + CASE_ISSIGNED(ComplexDouble); + CASE_ISSIGNED(Bool); + case ScalarType::Int1: + case ScalarType::Int2: + case ScalarType::Int3: + case ScalarType::Int4: + case ScalarType::Int5: + case ScalarType::Int6: + case ScalarType::Int7: + case ScalarType::Float4_e2m1fn_x2: + return true; + case ScalarType::UInt1: + case ScalarType::UInt2: + case ScalarType::UInt3: + case ScalarType::UInt4: + case ScalarType::UInt5: + case ScalarType::UInt6: + case ScalarType::UInt7: + return false; + case ScalarType::Undefined: + case ScalarType::NumOptions: + break; + // Do not add default here, but rather define behavior of every new entry + // here. `-Wswitch-enum` would raise a warning in those cases. + // TODO: get PyTorch to adopt exhaustive switches by default with a way to + // opt specific switches to being non-exhaustive. + // Exhaustive: + // `-Wswitch-enum`, `-Wswitch-default`, `-Wno-covered-switch-default` + // Non-Exhaustive: + // `-Wno-switch-enum`, `-Wswitch-default`, `-Wcovered-switch-default` + } + TORCH_CHECK(false, "Unknown ScalarType ", t); +#undef CASE_ISSIGNED +} + +inline bool isUnderlying(ScalarType type, ScalarType qtype) { + return type == toUnderlying(qtype); +} + +inline ScalarType toRealValueType(ScalarType t) { + switch (t) { + case ScalarType::ComplexHalf: + return ScalarType::Half; + case ScalarType::ComplexFloat: + return ScalarType::Float; + case ScalarType::ComplexDouble: + return ScalarType::Double; + default: + return t; + } +} + +inline ScalarType toComplexType(ScalarType t) { + switch (t) { + case ScalarType::BFloat16: + // BFloat16 has range equivalent to Float, + // so we map it to ComplexFloat. + return ScalarType::ComplexFloat; + case ScalarType::Half: + return ScalarType::ComplexHalf; + case ScalarType::Float: + return ScalarType::ComplexFloat; + case ScalarType::Double: + return ScalarType::ComplexDouble; + case ScalarType::ComplexHalf: + return ScalarType::ComplexHalf; + case ScalarType::ComplexFloat: + return ScalarType::ComplexFloat; + case ScalarType::ComplexDouble: + return ScalarType::ComplexDouble; + default: + TORCH_CHECK(false, "Unknown Complex ScalarType for ", t); + } +} + +// see tensor_attributes.rst for detailed explanation and examples +// of casting rules. +inline bool canCast(const ScalarType from, const ScalarType to) { + // We disallow complex -> non complex, e.g., float_tensor *= complex is + // disallowed. + if (isComplexType(from) && !isComplexType(to)) { + return false; + } + // We disallow float -> integral, e.g., int_tensor *= float is disallowed. + if (isFloatingType(from) && isIntegralType(to, false)) { + return false; + } + + // Treat bool as a distinct "category," to be consistent with type promotion + // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same + // category as `bool_tensor`, we would not promote. Differing categories + // implies `bool_tensor += 5` is disallowed. + // + // NB: numpy distinguishes "unsigned" as a category to get the desired + // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because: + // * We don't want the performance hit of checking the runtime sign of + // Scalars. + // * `uint8_tensor + 5 -> int64_tensor` would be undesirable. + if (from != ScalarType::Bool && to == ScalarType::Bool) { + return false; + } + return true; +} + +C10_API ScalarType promoteTypes(ScalarType a, ScalarType b); + +// Returns a pair of strings representing the names for each dtype. +// The returned pair is (name, legacy_name_if_applicable) +C10_API std::pair getDtypeNames( + c10::ScalarType scalarType); + +// Returns a map of string name to dtype. +C10_API const std::unordered_map& getStringToDtypeMap(); + +} // namespace c10 + +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ScalarTypeToTypeMeta.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ScalarTypeToTypeMeta.h new file mode 100644 index 0000000000000000000000000000000000000000..d952b0dd2089207bef2bd3b53d348d6cb667e046 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/ScalarTypeToTypeMeta.h @@ -0,0 +1,62 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +// these just expose TypeMeta/ScalarType bridge functions in c10 +// TODO move to typeid.h (or codemod away) when TypeMeta et al +// are moved from caffe2 to c10 (see note at top of typeid.h) + +namespace c10 { + +/** + * convert ScalarType enum values to TypeMeta handles + */ +inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { + return caffe2::TypeMeta::fromScalarType(scalar_type); +} + +/** + * convert TypeMeta handles to ScalarType enum values + */ +inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { + return dtype.toScalarType(); +} + +/** + * typeMetaToScalarType(), lifted to optional + */ +inline std::optional optTypeMetaToScalarType( + std::optional type_meta) { + if (!type_meta.has_value()) { + return std::nullopt; + } + return type_meta->toScalarType(); +} + +/** + * convenience: equality across TypeMeta/ScalarType conversion + */ +inline bool operator==(ScalarType t, caffe2::TypeMeta m) { + return m.isScalarType(t); +} + +inline bool operator==(caffe2::TypeMeta m, ScalarType t) { + return t == m; +} + +inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { + return !(t == m); +} + +inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { + return !(t == m); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Storage.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Storage.h new file mode 100644 index 0000000000000000000000000000000000000000..203eec24c05e28e413b69dc71fbb0b7be65538a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Storage.h @@ -0,0 +1,293 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +struct Storage; + +C10_API bool isSharedStorageAlias( + const Storage& storage0, + const Storage& storage1); + +struct C10_API Storage { + public: + struct use_byte_size_t {}; + struct unsafe_borrow_t { + explicit unsafe_borrow_t() = default; + }; + + Storage() = default; + Storage(c10::intrusive_ptr ptr) + : storage_impl_(std::move(ptr)) {} + + // Allocates memory buffer using given allocator and creates a storage with it + Storage( + use_byte_size_t /*use_byte_size*/, + const SymInt& size_bytes, + Allocator* allocator = nullptr, + bool resizable = false) + : storage_impl_(c10::make_intrusive( + StorageImpl::use_byte_size_t(), + size_bytes, + allocator, + resizable)) {} + + // Creates storage with pre-allocated memory buffer. Allocator is given for + // potential future reallocations, however it can be nullptr if the storage + // is non-resizable + Storage( + use_byte_size_t /*use_byte_size*/, + size_t size_bytes, + at::DataPtr data_ptr, + at::Allocator* allocator = nullptr, + bool resizable = false) + : storage_impl_(c10::make_intrusive( + StorageImpl::use_byte_size_t(), + size_bytes, + std::move(data_ptr), + allocator, + resizable)) {} + + // Creates storage with pre-allocated memory buffer. Allocator is given for + // potential future reallocations, however it can be nullptr if the storage + // is non-resizable + Storage( + use_byte_size_t /*use_byte_size*/, + SymInt size_bytes, + at::DataPtr data_ptr, + at::Allocator* allocator = nullptr, + bool resizable = false) + : storage_impl_(c10::make_intrusive( + StorageImpl::use_byte_size_t(), + std::move(size_bytes), + std::move(data_ptr), + allocator, + resizable)) {} + + protected: + explicit Storage(unsafe_borrow_t /*unused*/, const Storage& rhs) + : storage_impl_(c10::intrusive_ptr::reclaim( + rhs.storage_impl_.get())) {} + + friend MaybeOwnedTraits; + + public: + // Legacy constructor for partially initialized (dtype or memory) storages + // that can be temporarily created with Caffe2 APIs. See the note on top of + // TensorImpl.h for details. + static Storage create_legacy(at::Device device) { + auto allocator = GetAllocator(device.type()); + return Storage(c10::make_intrusive( + StorageImpl::use_byte_size_t(), + 0, + allocator->allocate(0), // materialize a non-default Device. + allocator, + true)); + } + + // Mimic create_legacy, but without requiring a newly-created StorageImpl. + void reset_legacy() { + TORCH_CHECK(resizable() && allocator()); + set_nbytes(0); + set_data_ptr_noswap(allocator()->allocate(0)); + } + + // TODO: remove later + void set_nbytes(size_t size_bytes) const { + storage_impl_->set_nbytes(size_bytes); + } + + void set_nbytes(c10::SymInt size_bytes) const { + storage_impl_->set_nbytes(std::move(size_bytes)); + } + + bool resizable() const { + return storage_impl_->resizable(); + } + + size_t nbytes() const { + return storage_impl_->nbytes(); + } + + SymInt sym_nbytes() const { + return storage_impl_->sym_nbytes(); + } + // get() use here is to get const-correctness + + const void* data() const { + return storage_impl_->data(); + } + + void* mutable_data() const { + return storage_impl_->mutable_data(); + } + + at::DataPtr& mutable_data_ptr() const { + return storage_impl_->mutable_data_ptr(); + } + + const at::DataPtr& data_ptr() const { + return storage_impl_->data_ptr(); + } + + // Returns the previous data_ptr + at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) const { + return storage_impl_->set_data_ptr(std::move(data_ptr)); + } + + void set_data_ptr_noswap(at::DataPtr&& data_ptr) const { + storage_impl_->set_data_ptr_noswap(std::move(data_ptr)); + } + + DeviceType device_type() const { + return storage_impl_->device_type(); + } + + at::Allocator* allocator() const { + return storage_impl_->allocator(); + } + + at::Device device() const { + return storage_impl_->device(); + } + + StorageImpl* unsafeReleaseStorageImpl() { + return storage_impl_.release(); + } + + StorageImpl* unsafeGetStorageImpl() const noexcept { + return storage_impl_.get(); + } + + c10::weak_intrusive_ptr getWeakStorageImpl() const { + return c10::weak_intrusive_ptr(storage_impl_); + } + + operator bool() const { + return storage_impl_; + } + + size_t use_count() const { + return storage_impl_.use_count(); + } + + inline bool unique() const { + return storage_impl_.unique(); + } + + bool is_alias_of(const Storage& other) const { + return ( + storage_impl_ == other.storage_impl_ || + isSharedStorageAlias(*this, other)); + } + + void UniqueStorageShareExternalPointer( + void* src, + size_t capacity, + DeleterFnPtr d = nullptr) { + if (!storage_impl_.unique()) { + TORCH_CHECK( + false, + "UniqueStorageShareExternalPointer can only be called when use_count == 1"); + } + storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d); + } + + void UniqueStorageShareExternalPointer( + at::DataPtr&& data_ptr, + size_t capacity) { + if (!storage_impl_.unique()) { + TORCH_CHECK( + false, + "UniqueStorageShareExternalPointer can only be called when use_count == 1"); + } + storage_impl_->UniqueStorageShareExternalPointer( + std::move(data_ptr), capacity); + } + + protected: + c10::intrusive_ptr storage_impl_; +}; + +template <> +struct MaybeOwnedTraits { + using owned_type = c10::Storage; + using borrow_type = c10::Storage; + + static borrow_type createBorrow(const owned_type& from) { + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseStorageImpl(); + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0. + } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits { + using repr_type = c10::Storage; + using pointer_type = c10::Storage*; + using const_pointer_type = const c10::Storage*; + + static repr_type nullRepr() { + return c10::Storage(); + } + + template + static repr_type createInPlace(Args&&... args) { + return c10::Storage(std::forward(args)...); + } + + static repr_type moveToRepr(c10::Storage&& x) { + return std::move(x); + } + + static c10::Storage take(c10::Storage& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/StorageImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/StorageImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..2acfa40771c5f29fb41565a06dfd6944a1a55ea4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/StorageImpl.h @@ -0,0 +1,398 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +[[noreturn]] C10_API void throwNullDataPtrError(); +C10_API void warnDeprecatedDataPtr(); + +// Used in StorageImpl to store extra metadata. +// Currently used only for storing a custom error message +// used when throwing an exception when data_ptr is accessed. +struct C10_API StorageExtraMeta { + std::optional custom_data_ptr_error_msg_ = std::nullopt; +}; + +// A storage represents the underlying backing data buffer for a +// tensor. This concept was inherited from the original Torch7 +// codebase; we'd kind of like to get rid of the concept +// (see https://github.com/pytorch/pytorch/issues/14797) but +// it's hard work and no one has gotten around to doing it. +// +// NB: storage is supposed to uniquely own a data pointer; e.g., +// two non-null data pointers alias if and only if they are from +// the same storage. Technically you can violate this invariant +// (e.g., you can create a non-owning StorageImpl with at::from_blob) +// but a lot of things won't work correctly, including: +// +// - An ordinary deleter on such a storage is wrong, because normal deleters +// assume unique ownership, but if you have two storages at the same data, +// that implies there is some sort of shared ownership. So your deleter would +// have to actually be internally doing some sort of refcount thing +// - Deepcopy in Python side relies on storage equality and not data pointer +// equality; so if there are two separate storages pointing to the same data, +// the data will actually get duplicated in that case (one data ptr before, +// two data ptrs after) +// - Version counts won't work correctly, because we do all VC tracking at the +// level of storages (unless you explicitly disconnect the VC with detach); +// mutation because data pointers are the same are totally untracked +struct C10_API StorageImpl : public c10::intrusive_ptr_target { + public: + struct use_byte_size_t {}; + + StorageImpl( + use_byte_size_t /*use_byte_size*/, + SymInt size_bytes, + at::DataPtr data_ptr, + at::Allocator* allocator, + bool resizable) + : data_ptr_(std::move(data_ptr)), + size_bytes_(std::move(size_bytes)), + size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()), + resizable_(resizable), + received_cuda_(false), + allocator_(allocator) { + if (resizable) { + TORCH_INTERNAL_ASSERT( + allocator_, "For resizable storage, allocator must be provided"); + } + refresh_has_data_ptr_check(); + } + + StorageImpl( + use_byte_size_t /*use_byte_size*/, + const SymInt& size_bytes, + at::Allocator* allocator, + bool resizable) + : StorageImpl( + use_byte_size_t(), + size_bytes, + size_bytes.is_heap_allocated() + ? allocator->allocate(0) + : allocator->allocate(size_bytes.as_int_unchecked()), + allocator, + resizable) {} + + StorageImpl& operator=(StorageImpl&& other) = delete; + StorageImpl& operator=(const StorageImpl&) = delete; + StorageImpl() = delete; + StorageImpl(StorageImpl&& other) = delete; + StorageImpl(const StorageImpl&) = delete; + ~StorageImpl() override = default; + + void reset() { + data_ptr_.clear(); + size_bytes_ = 0; + size_bytes_is_heap_allocated_ = false; + } + + // Destructor doesn't call release_resources because it's + // unnecessary; don't forget to change that if needed! + void release_resources() override { + data_ptr_.clear(); + } + + void incref_pyobject() const noexcept override final; + + void decref_pyobject() const noexcept override final; + + bool try_incref_pyobject() const noexcept override final; + + size_t nbytes() const { + // OK to do this instead of maybe_as_int as nbytes is guaranteed positive + TORCH_CHECK(!size_bytes_is_heap_allocated_); + return size_bytes_.as_int_unchecked(); + } + + SymInt sym_nbytes() const { + return size_bytes_; + } + + // TODO: remove later + void set_nbytes(size_t size_bytes) { + size_bytes_ = static_cast(size_bytes); + size_bytes_is_heap_allocated_ = false; + } + + void unsafe_set_nbytes(size_t size_bytes) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!size_bytes_is_heap_allocated_); + size_bytes_.unsafe_set_data(size_bytes); + } + + void set_nbytes(c10::SymInt size_bytes) { + size_bytes_ = std::move(size_bytes); + } + + bool resizable() const { + return resizable_; + } + + const at::DataPtr& data_ptr() const { + if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) { + throw_data_ptr_access_error(); + } + return data_ptr_; + } + + at::DataPtr& mutable_data_ptr() { + if (C10_UNLIKELY(has_mutable_data_ptr_check_)) { + if (throw_on_immutable_data_ptr_) { + throw_data_ptr_access_error(); + } + if (throw_on_mutable_data_ptr_) { + throwNullDataPtrError(); + } + if (warn_deprecated_on_mutable_data_ptr_) { + warnDeprecatedDataPtr(); + } + maybe_materialize_cow(); + } + return data_ptr_; + } + + // Returns the data_ptr. Bypasses all checks. + at::DataPtr& _mutable_data_ptr_no_checks() { + return data_ptr_; + } + + // Returns the previous data_ptr + at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) { + // We need to materialize the old COW DataPtr because it is + // being returned as mutable. + maybe_materialize_cow(); + return set_data_ptr_no_materialize_cow(std::move(data_ptr)); + } + + void set_data_ptr_noswap(at::DataPtr&& data_ptr) { + data_ptr_ = std::move(data_ptr); + refresh_has_data_ptr_check(); + } + + const void* data() const { + if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) { + throw_data_ptr_access_error(); + } + return data_ptr_.get(); + } + + void* mutable_data() { + if (C10_UNLIKELY(has_mutable_data_ptr_check_)) { + if (throw_on_immutable_data_ptr_) { + throw_data_ptr_access_error(); + } + if (throw_on_mutable_data_ptr_) { + throwNullDataPtrError(); + } + if (warn_deprecated_on_mutable_data_ptr_) { + warnDeprecatedDataPtr(); + } + maybe_materialize_cow(); + } + return data_ptr_.mutable_get(); + } + + at::DeviceType device_type() const { + return data_ptr_.device().type(); + } + + at::Allocator* allocator() { + return allocator_; + } + + const at::Allocator* allocator() const { + return allocator_; + } + + // You generally shouldn't use this method, but it is occasionally + // useful if you want to override how a tensor will be reallocated, + // after it was already allocated (and its initial allocator was + // set) + void set_allocator(at::Allocator* allocator) { + allocator_ = allocator; + } + + Device device() const { + return data_ptr_.device(); + } + + void set_resizable(bool resizable) { + if (resizable) { + // We need an allocator to be resizable + AT_ASSERT(allocator_); + } + resizable_ = resizable; + } + + /** + * Can only be called when use_count is 1 + */ + void UniqueStorageShareExternalPointer( + void* src, + size_t size_bytes, + DeleterFnPtr d = nullptr) { + UniqueStorageShareExternalPointer( + at::DataPtr(src, src, d, data_ptr_.device()), size_bytes); + } + + /** + * Can only be called when use_count is 1 + */ + void UniqueStorageShareExternalPointer( + at::DataPtr&& data_ptr, + size_t size_bytes) { + data_ptr_ = std::move(data_ptr); + size_bytes_ = static_cast(size_bytes); + size_bytes_is_heap_allocated_ = false; + allocator_ = nullptr; + resizable_ = false; + } + + // This method can be used only after storage construction and cannot be used + // to modify storage status + void set_received_cuda(bool received_cuda) { + received_cuda_ = received_cuda; + } + + bool received_cuda() { + return received_cuda_; + } + + impl::PyObjectSlot* pyobj_slot() { + return &pyobj_slot_; + } + + const impl::PyObjectSlot* pyobj_slot() const { + return &pyobj_slot_; + } + + StorageExtraMeta& get_extra_meta() { + if (!extra_meta_) { + extra_meta_ = std::make_unique(); + } + return *extra_meta_; + } + + [[noreturn]] void throw_data_ptr_access_error() const; + + void release_data_and_set_meta_custom_data_ptr_error_msg_( + std::optional s) { + throw_on_immutable_data_ptr_ = true; + get_extra_meta().custom_data_ptr_error_msg_ = std::move(s); + refresh_has_data_ptr_check(); + } + + void set_throw_on_mutable_data_ptr() { + throw_on_mutable_data_ptr_ = true; + refresh_has_data_ptr_check(); + } + + void set_warn_deprecated_on_mutable_data_ptr() { + warn_deprecated_on_mutable_data_ptr_ = true; + refresh_has_data_ptr_check(); + } + + protected: + // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow + friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage); + + // Returns the previous data_ptr. If the old data_ptr was COW, + // this avoids materializing it + at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) { + at::DataPtr old_data_ptr(std::move(data_ptr_)); + data_ptr_ = std::move(data_ptr); + refresh_has_data_ptr_check(); + return old_data_ptr; + } + + private: + void refresh_has_data_ptr_check() { + has_mutable_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ || + warn_deprecated_on_mutable_data_ptr_ || throw_on_immutable_data_ptr_; + } + + inline bool is_cow() const { + return c10::impl::cow::is_cow_data_ptr(data_ptr_); + } + + // Triggers a copy if this is a copy-on-write tensor. + void maybe_materialize_cow() { + if (is_cow()) { + impl::cow::materialize_cow_storage(*this); + } + } + + DataPtr data_ptr_; + SymInt size_bytes_; + bool size_bytes_is_heap_allocated_; + bool resizable_; + // Identifies that Storage was received from another process and doesn't have + // local to process cuda memory allocation + bool received_cuda_; + // All special checks in data/data_ptr calls are guarded behind this single + // boolean. This is for performance: .data/.data_ptr calls are commonly in the + // hot-path. + bool has_mutable_data_ptr_check_ = false; + // If we should throw when mutable_data_ptr() or mutable_data() is called. + bool throw_on_mutable_data_ptr_ = false; + // If we should throw when data_ptr() or data() is called. + bool throw_on_immutable_data_ptr_ = false; + // If we warn when mutable_data_ptr() or mutable_data() is called. + bool warn_deprecated_on_mutable_data_ptr_ = false; + Allocator* allocator_; + impl::PyObjectSlot pyobj_slot_; + std::unique_ptr extra_meta_ = nullptr; +}; + +// Declare StorageImpl create function pointer types. +using StorageImplCreateHelper = intrusive_ptr (*)( + StorageImpl::use_byte_size_t, + SymInt size_bytes, + DataPtr data_ptr, + Allocator* allocator, + bool resizable); + +C10_API void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr); + +C10_API StorageImplCreateHelper GetStorageImplCreate(DeviceType t); + +C10_API c10::intrusive_ptr make_storage_impl( + c10::StorageImpl::use_byte_size_t use_byte_size, + c10::SymInt size_bytes, + c10::DataPtr data_ptr, + c10::Allocator* allocator, + bool resizable, + std::optional device_opt); + +namespace detail { + +#ifndef C10_MOBILE +template +struct TargetTraits< + T, + std::enable_if_t< + std::is_base_of_v>>> { + static constexpr bool can_have_pyobject = true; +}; +#endif + +} // namespace detail + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Stream.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Stream.h new file mode 100644 index 0000000000000000000000000000000000000000..4d3a50984ec6e9093a321b7df2855383758e50ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/Stream.h @@ -0,0 +1,182 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// An index representing a specific stream. A StreamId is not independently +/// meaningful without knowing the Device it is associated with; try to +/// use Stream rather than StreamId directly. +/// +/// StreamIds are opaque; they are assigned by some DeviceType-specific +/// numbering system which is not visible to the user. HOWEVER, we +/// guarantee that StreamId 0 is always a valid stream, and corresponds +/// to some sort of "default" stream. +using StreamId = int64_t; + +struct C10_API StreamData3 { + StreamId stream_id; + DeviceIndex device_index; + DeviceType device_type; +}; + +// NB: I decided not to call the above StreamIndex to avoid confusion with +// DeviceIndex. This way, you access device index with index(), and stream id +// with id() + +/** + * A stream is a software mechanism used to synchronize launched kernels + * without requiring explicit synchronizations between kernels. The basic + * model is that every kernel launch is associated with a stream: every + * kernel on the same stream is implicitly synchronized so that if I launch + * kernels A and B on the same stream, A is guaranteed to finish before B + * launches. If I want B to run concurrently with A, I must schedule + * it on a different stream. + * + * The Stream class is a backend agnostic value class representing a stream + * which I may schedule a kernel on. Every stream is associated with a device, + * which is recorded in stream, which is used to avoid confusion about which + * device a stream refers to. + * + * Streams are explicitly thread-safe, in the sense that it is OK to pass + * a Stream from one thread to another, and kernels queued from two different + * threads will still get serialized appropriately. (Of course, the + * time when the kernels get queued is undetermined unless you synchronize + * host side ;) + * + * Stream does NOT have a default constructor. Streams are for expert + * users; if you want to use Streams, we're going to assume you know + * how to deal with C++ template error messages if you try to + * resize() a vector of Streams. + * + * Known instances of streams in backends: + * + * - cudaStream_t (CUDA) + * - hipStream_t (HIP) + * - cl_command_queue (OpenCL) (NB: Caffe2's existing OpenCL integration + * does NOT support command queues.) + * + * Because this class is device agnostic, it cannot provide backend-specific + * functionality (e.g., get the cudaStream_t of a CUDA stream.) There are + * wrapper classes which provide this functionality, e.g., CUDAStream. + */ +class C10_API Stream final { + private: + Device device_; + StreamId id_; + + public: + enum Unsafe { UNSAFE }; + enum Default { DEFAULT }; + + /// Unsafely construct a stream from a Device and a StreamId. In + /// general, only specific implementations of streams for a + /// backend should manufacture Stream directly in this way; other users + /// should use the provided APIs to get a stream. In particular, + /// we don't require backends to give any guarantees about non-zero + /// StreamIds; they are welcome to allocate in whatever way they like. + explicit Stream(Unsafe /*unused*/, Device device, StreamId id) + : device_(device), id_(id) {} + + /// Construct the default stream of a Device. The default stream is + /// NOT the same as the current stream; default stream is a fixed stream + /// that never changes, whereas the current stream may be changed by + /// StreamGuard. + explicit Stream(Default /*unused*/, Device device) + : device_(device), id_(0) {} + + bool operator==(const Stream& other) const noexcept { + return this->device_ == other.device_ && this->id_ == other.id_; + } + bool operator!=(const Stream& other) const noexcept { + return !(*this == other); + } + + Device device() const noexcept { + return device_; + } + DeviceType device_type() const noexcept { + return device_.type(); + } + DeviceIndex device_index() const noexcept { + return device_.index(); + } + StreamId id() const noexcept { + return id_; + } + + // Enqueues a wait instruction in the stream's work queue. + // This instruction is a no-op unless the event is marked + // for recording. In that case the stream stops processing + // until the event is recorded. + template + void wait(const T& event) const { + event.block(*this); + } + + // Return whether all asynchronous work previously enqueued on this stream + // has completed running on the device. + bool query() const; + + // Wait (by blocking the calling thread) until all asynchronous work enqueued + // on this stream has completed running on the device. + void synchronize() const; + + // The purpose of this function is to more conveniently permit binding + // of Stream to and from Python. Without packing, I have to setup a whole + // class with two fields (device and stream id); with packing I can just + // store a single uint64_t. + // + // The particular way we pack streams into a uint64_t is considered an + // implementation detail and should not be relied upon. + uint64_t hash() const noexcept { + // Concat these together into a 64-bit integer + uint64_t bits = static_cast(device_type()) << 56 | + static_cast(device_index()) << 48 | + // Remove the sign extension part of the 64-bit address because + // the id might be used to hold a pointer. + (static_cast(id()) & ((1ull << 48) - 1)); + return bits; + } + + struct StreamData3 pack3() const { + return {id(), device_index(), device_type()}; + } + + static Stream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + TORCH_CHECK(isValidDeviceType(device_type)); + return Stream(UNSAFE, Device(device_type, device_index), stream_id); + } + + // I decided NOT to provide setters on this class, because really, + // why would you change the device of a stream? Just construct + // it correctly from the beginning dude. +}; + +C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s); + +} // namespace c10 + +namespace std { +template <> +struct hash { + size_t operator()(c10::Stream s) const noexcept { + return std::hash{}(s.hash()); + } +}; +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/StreamGuard.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/StreamGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..003816d62f6ce12223cc5106eee6ae37a26e04e9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/StreamGuard.h @@ -0,0 +1,178 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/** + * A StreamGuard is an RAII class that changes the current device + * to the device corresponding to some stream, and changes the + * default stream on that device to be this stream. + * + * Use of StreamGuard is HIGHLY discouraged in operator definitions. In + * a single operator, you probably don't know enough about the global + * state of the world to profitably decide how to set streams. Let + * the caller handle this appropriately, and just use the current stream + * in your operator code. + * + * This StreamGuard does NOT have an uninitialized state; it is guaranteed + * to reset the stream and device on exit. If you are in a situation + * where you *might* want to setup a stream guard, see OptionalStreamGuard. + */ +struct StreamGuard { + /// No default constructor, see Note [Omitted default constructor from RAII] + explicit StreamGuard() = delete; + ~StreamGuard() = default; + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + explicit StreamGuard(Stream stream) : guard_(stream) {} + + /// Copy is disallowed + StreamGuard(const StreamGuard&) = delete; + StreamGuard& operator=(const StreamGuard&) = delete; + + /// Move is disallowed, as StreamGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + StreamGuard(StreamGuard&& other) = delete; + StreamGuard& operator=(StreamGuard&& other) = delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// + /// NOTE: this implementation may skip some stream/device setting if + /// it can prove that it is unnecessary. + /// + /// WARNING: reset_stream does NOT preserve previously set streams on + /// different devices. If you need to set streams on multiple devices + /// on , use MultiStreamGuard instead. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the stream that was set at the time the guard was constructed. + Stream original_stream() const { + return guard_.original_stream(); + } + + /// Returns the most recent stream that was set using this device guard, + /// either from construction, or via set_stream. + Stream current_stream() const { + return guard_.current_stream(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return guard_.current_device(); + } + + /// Returns the device that was set at the most recent reset_stream(), + /// or otherwise the device at construction time. + Device original_device() const { + return guard_.original_device(); + } + + private: + c10::impl::InlineStreamGuard guard_; +}; + +/** + * An OptionalStreamGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * See OptionalDeviceGuard for more guidance on how to use this class. + */ +struct OptionalStreamGuard { + /// Create an uninitialized guard. + explicit OptionalStreamGuard() = default; + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + explicit OptionalStreamGuard(Stream stream) : guard_(stream) {} + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream, + /// if the passed stream is not nullopt. + explicit OptionalStreamGuard(std::optional stream_opt) + : guard_(stream_opt) {} + + /// Copy is disallowed + OptionalStreamGuard(const OptionalStreamGuard&) = delete; + OptionalStreamGuard& operator=(const OptionalStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalStreamGuard(OptionalStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalStreamGuard& operator=(OptionalStreamGuard&& other) = delete; + ~OptionalStreamGuard() = default; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Initializes the guard if it was not previously initialized. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the stream that was set at the time the guard was most recently + /// initialized, or nullopt if the guard is uninitialized. + std::optional original_stream() const { + return guard_.original_stream(); + } + + /// Returns the most recent stream that was set using this stream guard, + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + std::optional current_stream() const { + return guard_.current_stream(); + } + + /// Restore the original device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalStreamGuard guard_; +}; + +/** + * A MultiStreamGuard is an RAII class that sets the current streams of a set of + * devices all at once, and resets them to their original values on destruction. + */ +struct MultiStreamGuard { + /// Set the current streams to the passed streams on each of their respective + /// devices. + explicit MultiStreamGuard(ArrayRef streams) : guard_(streams) {} + + /// Copy is disallowed + MultiStreamGuard(const MultiStreamGuard&) = delete; + MultiStreamGuard& operator=(const MultiStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + MultiStreamGuard(MultiStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete; + ~MultiStreamGuard() = default; + + private: + c10::impl::InlineMultiStreamGuard guard_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymBool.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymBool.h new file mode 100644 index 0000000000000000000000000000000000000000..d12fa75fb41446f3f9967a73aed8a25fc1a60f4b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymBool.h @@ -0,0 +1,184 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +class SymInt; + +class C10_API SymBool { + public: + /*implicit*/ SymBool(bool b) : data_(b) {} + SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) { + TORCH_CHECK(ptr_->is_bool()); + } + SymBool() : data_(false) {} + + SymNodeImpl* toSymNodeImplUnowned() const { + return ptr_.get(); + } + + SymNodeImpl* release() && { + return std::move(ptr_).release(); + } + + // Only valid if is_heap_allocated() + SymNode toSymNodeImpl() const; + + // Guaranteed to return a SymNode, wrapping using base if necessary + SymNode wrap_node(const SymNode& base) const; + + bool expect_bool() const { + std::optional c = maybe_as_bool(); + TORCH_CHECK(c.has_value()); + return *c; + } + + SymBool sym_and(const SymBool& /*sci*/) const; + SymBool sym_or(const SymBool& /*sci*/) const; + SymBool sym_not() const; + + SymBool operator&(const SymBool& other) const { + return sym_and(other); + } + SymBool operator|(const SymBool& other) const { + return sym_or(other); + } + SymBool operator||(const SymBool& other) const { + return sym_or(other); + } + SymBool operator~() const { + return sym_not(); + } + + // Insert a guard for the bool to be its concrete value, and then return + // that value. Note that C++ comparison operations default to returning + // bool, so it's not so common to have to call this + bool guard_bool(const char* file, int64_t line) const; + bool expect_true(const char* file, int64_t line) const; + bool guard_size_oblivious(const char* file, int64_t line) const; + bool statically_known_true(const char* file, int64_t line) const; + bool guard_or_false(const char* file, int64_t line) const; + bool guard_or_true(const char* file, int64_t line) const; + + bool has_hint() const; + + bool as_bool_unchecked() const { + return data_; + } + + std::optional maybe_as_bool() const { + if (!is_heap_allocated()) { + return data_; + } + return toSymNodeImplUnowned()->constant_bool(); + } + + // Convert SymBool to SymInt (0 or 1) + // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless + SymInt toSymInt() const; + + bool is_heap_allocated() const { + return ptr_; + } + + private: + // TODO: optimize to union + bool data_; + SymNode ptr_; +}; + +C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s); + +#define TORCH_SYM_CHECK(cond, ...) \ + TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) +#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \ + TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) +#define TORCH_MAYBE_SYM_CHECK(cond, ...) \ + if constexpr (std::is_same_v, SymBool>) { \ + TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) \ + } else { \ + TORCH_CHECK((cond), __VA_ARGS__) \ + } + +inline bool guard_size_oblivious( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool guard_size_oblivious( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_size_oblivious(file, line); +} + +inline bool guard_or_false( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool guard_or_false( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_or_false(file, line); +} + +inline bool statically_known_true( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool statically_known_true( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.statically_known_true(file, line); +} + +inline bool guard_or_true( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool guard_or_true( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_or_true(file, line); +} + +#define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \ + c10::guard_size_oblivious((cond), __FILE__, __LINE__) + +#define TORCH_STATICALLY_KNOWN_TRUE(cond) \ + c10::statically_known_true((cond), __FILE__, __LINE__) + +#define TORCH_GUARD_OR_FALSE(cond) \ + c10::guard_or_false((cond), __FILE__, __LINE__) + +#define TORCH_GUARD_OR_TRUE(cond) c10::guard_or_true((cond), __FILE__, __LINE__) + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymFloat.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymFloat.h new file mode 100644 index 0000000000000000000000000000000000000000..332726ba4c5dade5accef6a3dac6076366c04d95 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymFloat.h @@ -0,0 +1,123 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +// NB: this is actually double precision; we're using the Python naming here +class C10_API SymFloat { + public: + /*implicit*/ SymFloat(double d) : data_(d) {} + SymFloat(SymNode ptr) + : data_(std::numeric_limits::quiet_NaN()), ptr_(std::move(ptr)) { + TORCH_CHECK(ptr_->is_float()); + } + SymFloat() : data_(0.0) {} + + SymNodeImpl* toSymNodeImplUnowned() const { + return ptr_.get(); + } + + SymNodeImpl* release() && { + return std::move(ptr_).release(); + } + + // Only valid if is_symbolic() + SymNode toSymNodeImpl() const; + + // Guaranteed to return a SymNode, wrapping using base if necessary + SymNode wrap_node(const SymNode& base) const; + + double expect_float() const { + TORCH_CHECK(!is_symbolic()); + return data_; + } + + SymFloat operator+(const SymFloat& /*sci*/) const; + SymFloat operator-(const SymFloat& /*sci*/) const; + SymFloat operator*(const SymFloat& /*sci*/) const; + SymFloat operator/(const SymFloat& /*sci*/) const; + + SymBool sym_eq(const SymFloat& /*sci*/) const; + SymBool sym_ne(const SymFloat& /*sci*/) const; + SymBool sym_lt(const SymFloat& /*sci*/) const; + SymBool sym_le(const SymFloat& /*sci*/) const; + SymBool sym_gt(const SymFloat& /*sci*/) const; + SymBool sym_ge(const SymFloat& /*sci*/) const; + + bool operator==(const SymFloat& o) const { + return sym_eq(o).guard_bool(__FILE__, __LINE__); + } + bool operator!=(const SymFloat& o) const { + return sym_ne(o).guard_bool(__FILE__, __LINE__); + } + bool operator<(const SymFloat& o) const { + return sym_lt(o).guard_bool(__FILE__, __LINE__); + } + bool operator<=(const SymFloat& o) const { + return sym_le(o).guard_bool(__FILE__, __LINE__); + } + bool operator>(const SymFloat& o) const { + return sym_gt(o).guard_bool(__FILE__, __LINE__); + } + bool operator>=(const SymFloat& o) const { + return sym_ge(o).guard_bool(__FILE__, __LINE__); + } + + SymFloat min(const SymFloat& sci) const; + SymFloat max(const SymFloat& sci) const; + + // Need guidance on where to put this code + SymFloat sqrt() const; + + // Insert a guard for the float to be its concrete value, and then return + // that value. This operation always works, even if the float is symbolic, + // so long as we know what the underlying value is. Don't blindly put this + // everywhere; you can cause overspecialization of PyTorch programs with + // this method. + // + // It should be called as guard_float(__FILE__, __LINE__). The file and line + // number can be used to diagnose overspecialization. + double guard_float(const char* file, int64_t line) const; + + bool has_hint() const; + + // N.B. It's important to keep this definition in the header + // as we expect if checks to be folded for mobile builds + // where `is_symbolic` is always false + C10_ALWAYS_INLINE bool is_symbolic() const { + return ptr_; + } + + // UNSAFELY coerce this SymFloat into a double. You MUST have + // established that this is a non-symbolic by some other means, + // typically by having tested is_symbolic(). You will get garbage + // from this function if is_symbolic() + double as_float_unchecked() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_symbolic()); + return data_; + } + + private: + // TODO: optimize to union + double data_; + SymNode ptr_; +}; + +C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s); +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymInt.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymInt.h new file mode 100644 index 0000000000000000000000000000000000000000..f9fa7f645047dbf5f8a2f1831d362606e8d98e98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymInt.h @@ -0,0 +1,586 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +class SymFloat; + +// SymInt represents either a regular int64_t, or a symbolic integer +// (represented in a type erased way as SymNode). The intention is for SymInt +// to represent symbolic sizes that arise when doing shape computation in +// operator kernels. This allows for tracing through programs without baking in +// concrete sizes into kernel calls. +// +// SymInt has an API equivalent to int64_t. In particular, it is a value type. +// Internally, SymInt is represented in a clever packed way, so that it only +// occupies one word of space; but morally, it is a union between an int64_t +// and an intrusive pointer to SymNodeImpl. +// +// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where +// is_int() returns true + +class C10_API SymInt { + public: + enum Unchecked { + UNCHECKED, + }; + + /*implicit*/ SymInt(int64_t d) : data_(d) { + if (is_heap_allocated()) { + // Large negative number, heap allocate it + promote_to_negative(); + } + } + SymInt() : data_(0) {} + SymInt(SymNode n); + + // unchecked c-tor accepting raw `data_` + // One appropriate use for this is when you are constructing a symint + // in a situation where you know it is non-negative (or, if it is negative, + // the negative value is -1; i.e., not user controlled) + SymInt(Unchecked /*unused*/, int64_t d) : data_(d) {} + + // TODO: these implementations are not optimal because they allocate a + // temporary and then use the move constructor/assignment + SymInt(const SymInt& s) : data_(0) { + if (s.is_heap_allocated()) { + *this = SymInt(s.toSymNode()); + } else { + data_ = s.data_; + } + } + SymInt(SymInt&& s) noexcept : data_(s.data_) { + s.data_ = 0; + } + + SymInt& operator=(const SymInt& s) { + if (this != &s) { + if (s.is_heap_allocated()) { + *this = SymInt(s.toSymNode()); + } else { + data_ = s.data_; + } + } + return *this; + } + SymInt& operator=(SymInt&& s) noexcept { + if (this != &s) { + release_(); // release the current SymNode if any + data_ = s.data_; + if (s.is_heap_allocated()) + s.data_ = 0; + }; + return *this; + } + + SymNodeImpl* toSymNodeImplUnowned() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated()); + uint64_t unextended_bits = static_cast(data_) & ~MASK; + uint64_t sign_bit_mask = 1ULL << (62 - 1); + // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c + uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; + return static_cast( + // NOLINTNEXTLINE(performance-no-int-to-ptr, bugprone*) + reinterpret_cast(static_cast(extended_bits))); + } + + void release_() { + if (is_heap_allocated()) { + SymNode::reclaim(toSymNodeImplUnowned()); // steal + } + } + + SymNodeImpl* release() && { +#ifndef C10_MOBILE + TORCH_INTERNAL_ASSERT(is_heap_allocated()); + auto* r = toSymNodeImplUnowned(); + data_ = 0; // transfer ownership + return r; +#else + TORCH_INTERNAL_ASSERT(false); +#endif + } + + // Only valid if is_heap_allocated() + SymNode toSymNode() const; + + // Guaranteed to return a SymNode, wrapping using base if necessary + SymNode wrap_node(const SymNode& base) const; + + ~SymInt() { + release_(); + } + + // Require the int to be non-symbolic, and if it is symbolic raise an + // error. This is safe to use for C++ code that doesn't work for symbolic + // shapes, and you don't have time to fix it immediately, as if we + // try to trigger the path in C++ you'll appropriately get an error + int64_t expect_int() const { + if (auto r = maybe_as_int()) { + return *r; + } + TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE( + false, "when unpacking SymInt, expected int but got ", *this); + } + + // Test if we have a hint for this int (e.g., guard_int would work). + // Most of the time this is true; it is only false when you have + // an unbacked SymInt. + bool has_hint() const; + + // Insert a guard for the int to be its concrete value, and then return + // that value. This operation always works, even if the int is symbolic, + // so long as we know what the underlying value is (e.g., this won't work + // if you call it on the size of nonzero output). Don't blindly put this + // everywhere; you can cause overspecialization of PyTorch programs with + // this method. + // + // It should be called as guard_int(__FILE__, __LINE__). The file and line + // number can be used to diagnose overspecialization. + int64_t guard_int(const char* file, int64_t line) const; + + // Distinguish actual symbolic values from constants stored on the heap + bool is_symbolic() const { + return is_heap_allocated() && + !toSymNodeImplUnowned()->constant_int().has_value(); + } + + // N.B. It's important to keep this definition in the header + // as we expect if checks to be folded for mobile builds + // where `is_heap_allocated` is always false and optimize dead code paths + C10_ALWAYS_INLINE bool is_heap_allocated() const { +#ifdef C10_MOBILE + return false; +#else + return !check_range(data_); +#endif + } + + SymInt operator+(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma + *mb); + } + } + return operator_add_slow_path(sci); + } + + SymInt operator-(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma - *mb); + } + } + return operator_sub_slow_path(sci); + } + + SymInt operator*(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma * *mb); + } + } + return operator_mul_slow_path(sci); + } + + SymInt operator/(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma / *mb); + } + } + return operator_div_slow_path(sci); + } + + SymInt operator%(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(*ma % *mb); + } + } + return operator_mod_slow_path(sci); + } + + void operator*=(const SymInt& sci) { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + *this = SymInt(*ma * *mb); + return; + } + } + operator_imul_slow_path(sci); + } + + void operator+=(const SymInt& sci) { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + *this = SymInt(*ma + *mb); + return; + } + } + operator_iadd_slow_path(sci); + } + + void operator/=(const SymInt& sci) { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + *this = SymInt(*ma / *mb); + return; + } + } + operator_idiv_slow_path(sci); + } + + SymInt clone() const; + + SymBool sym_eq(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma == *mb); + } + } + return sym_eq_slow_path(sci); + } + + SymBool sym_ne(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma != *mb); + } + } + return sym_ne_slow_path(sci); + } + + SymBool sym_lt(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma < *mb); + } + } + return sym_lt_slow_path(sci); + } + + SymBool sym_le(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma <= *mb); + } + } + return sym_le_slow_path(sci); + } + + SymBool sym_gt(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma > *mb); + } + } + return sym_gt_slow_path(sci); + } + + SymBool sym_ge(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymBool(*ma >= *mb); + } + } + return sym_ge_slow_path(sci); + } + + bool operator==(const SymInt& o) const { + return sym_eq(o).guard_bool(__FILE__, __LINE__); + } + bool operator!=(const SymInt& o) const { + return sym_ne(o).guard_bool(__FILE__, __LINE__); + } + bool operator<(const SymInt& o) const { + return sym_lt(o).guard_bool(__FILE__, __LINE__); + } + bool operator<=(const SymInt& o) const { + return sym_le(o).guard_bool(__FILE__, __LINE__); + } + bool operator>(const SymInt& o) const { + return sym_gt(o).guard_bool(__FILE__, __LINE__); + } + bool operator>=(const SymInt& o) const { + return sym_ge(o).guard_bool(__FILE__, __LINE__); + } + + SymInt min(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(std::min(*ma, *mb)); + } + } + return min_slow_path(sci); + } + + SymInt max(const SymInt& sci) const { + if (auto ma = maybe_as_int()) { + if (auto mb = sci.maybe_as_int()) { + return SymInt(std::max(*ma, *mb)); + } + } + return max_slow_path(sci); + } + + // If both are symbolic, this checks if + // they share the same node. + // If both are not symbolic this just checks normal equality. + bool is_same(const SymInt& other) const; + + operator SymFloat() const; + + void unsafe_set_data(size_t nbytes) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated()); + data_ = static_cast(nbytes); + } + + // Don't use this. Prefer maybe_as_int instead + int64_t as_int_unchecked() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated()); + return data_; + } + + std::optional maybe_as_int() const { + if (!is_heap_allocated()) { + return data_; + } + return maybe_as_int_slow_path(); + } + + // Return whether the integer is directly coercible to a SymInt + // without requiring heap allocation. You don't need to use this + // to check if you can pass an integer to SymInt; this is guaranteed + // to work (it just might heap allocate!) + static bool check_range(int64_t i) { + return i > MAX_UNREPRESENTABLE_INT; + } + + // Return the min representable integer as a SymInt without + // heap allocation. For quantities that count bytes (or larger), + // this is still much larger than you need, so you may consider + // using this as a more efficient version of MIN_INT + static constexpr int64_t min_representable_int() { + return MAX_UNREPRESENTABLE_INT + 1; + } + + private: + void promote_to_negative(); + SymInt operator_add_slow_path(const SymInt& sci) const; + SymInt operator_sub_slow_path(const SymInt& sci) const; + SymInt operator_mul_slow_path(const SymInt& sci) const; + SymInt operator_div_slow_path(const SymInt& sci) const; + SymInt operator_mod_slow_path(const SymInt& sci) const; + void operator_imul_slow_path(const SymInt& sci); + void operator_iadd_slow_path(const SymInt& sci); + void operator_idiv_slow_path(const SymInt& sci); + SymBool sym_eq_slow_path(const SymInt& sci) const; + SymBool sym_ne_slow_path(const SymInt& sci) const; + SymBool sym_lt_slow_path(const SymInt& sci) const; + SymBool sym_le_slow_path(const SymInt& sci) const; + SymBool sym_gt_slow_path(const SymInt& sci) const; + SymBool sym_ge_slow_path(const SymInt& sci) const; + + SymInt min_slow_path(const SymInt& sci) const; + SymInt max_slow_path(const SymInt& sci) const; + + std::optional maybe_as_int_slow_path() const; + + // Constraints on the internal representation: + // + // - Should represent positive and small negative ints + // - No conversion necessary for operations on ints + // - Must represent valid 64-bit pointers + // - Is symbolic test should be FAST (two arithmetic instructions is too + // much). + // This code being a hotpath is based on Strobelight profiles of + // is_heap_allocated(). FB only: https://fburl.com/strobelight/5l50ncxd + // (you will need to change the time window). + // + // So, the scheme is to reserve large negative numbers (assuming + // two's complement): + // + // - 0b0.... means we are a positive int + // - 0b11... means we are a small negative int + // - 0b10... means we are are a pointer. This means that + // [-2^63, -2^62-1] are not representable as ints. + // We don't actually need all of this space as on x86_64 + // as the top 16bits aren't used for anything + static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61; + static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61; + // We must manually translate the bit pattern test into a greater + // than test because compiler doesn't figure it out: + // https://godbolt.org/z/356aferaW + static constexpr int64_t MAX_UNREPRESENTABLE_INT = + -1LL & static_cast(~(1ULL << 62)); + int64_t data_; +}; + +/// Sum of a list of SymInt; accumulates into the c10::SymInt expression +template < + typename C, + typename std::enable_if_t< + std::is_same_v, + int> = 0> +inline c10::SymInt multiply_integers(const C& container) { + return std::accumulate( + container.begin(), + container.end(), + c10::SymInt(1), + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + +template < + typename Iter, + typename = std::enable_if_t::value_type, + c10::SymInt>>> +inline c10::SymInt multiply_integers(Iter begin, Iter end) { + return std::accumulate( + begin, + end, + c10::SymInt(1), + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + +#define DECLARE_SYMINT_OP_INTONLY(scalar_t, RetTy) \ + C10_API RetTy operator%(const SymInt& a, scalar_t b); \ + C10_API RetTy operator%(scalar_t a, const SymInt& b); + +#define DECLARE_SYMINT_OP(scalar_t, RetTy) \ + C10_API RetTy operator+(const SymInt& a, scalar_t b); \ + C10_API RetTy operator-(const SymInt& a, scalar_t b); \ + C10_API RetTy operator*(const SymInt& a, scalar_t b); \ + C10_API RetTy operator/(const SymInt& a, scalar_t b); \ + C10_API RetTy operator+(scalar_t a, const SymInt& b); \ + C10_API RetTy operator-(scalar_t a, const SymInt& b); \ + C10_API RetTy operator*(scalar_t a, const SymInt& b); \ + C10_API RetTy operator/(scalar_t a, const SymInt& b); \ + C10_API bool operator==(const SymInt& a, scalar_t b); \ + C10_API bool operator!=(const SymInt& a, scalar_t b); \ + C10_API bool operator<(const SymInt& a, scalar_t b); \ + C10_API bool operator<=(const SymInt& a, scalar_t b); \ + C10_API bool operator>(const SymInt& a, scalar_t b); \ + C10_API bool operator>=(const SymInt& a, scalar_t b); \ + C10_API bool operator==(scalar_t a, const SymInt& b); \ + C10_API bool operator!=(scalar_t a, const SymInt& b); \ + C10_API bool operator<(scalar_t a, const SymInt& b); \ + C10_API bool operator<=(scalar_t a, const SymInt& b); \ + C10_API bool operator>(scalar_t a, const SymInt& b); \ + C10_API bool operator>=(scalar_t a, const SymInt& b); + +DECLARE_SYMINT_OP_INTONLY(int64_t, SymInt) +DECLARE_SYMINT_OP_INTONLY(int32_t, SymInt) +DECLARE_SYMINT_OP_INTONLY(uint64_t, SymInt) +DECLARE_SYMINT_OP_INTONLY(uint32_t, SymInt) +DECLARE_SYMINT_OP(int64_t, SymInt) +DECLARE_SYMINT_OP(int32_t, SymInt) // make sure constants work +DECLARE_SYMINT_OP(uint64_t, SymInt) +DECLARE_SYMINT_OP(uint32_t, SymInt) +DECLARE_SYMINT_OP(double, SymFloat) +DECLARE_SYMINT_OP(float, SymFloat) // just for completeness + +// On OSX size_t is different than uint64_t so we have to +// define it separately +#if defined(__APPLE__) +DECLARE_SYMINT_OP_INTONLY(size_t, SymInt) +DECLARE_SYMINT_OP(size_t, SymInt) +#endif + +#undef DECLARE_SYMINT_OP + +C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s); +C10_API SymInt operator-(const SymInt& s); + +inline bool sym_eq(int64_t a, int64_t b) { + return a == b; +} + +inline SymBool sym_eq(const SymInt& a, const SymInt& b) { + return a.sym_eq(b); +} + +inline bool sym_ne(int64_t a, int64_t b) { + return a != b; +} + +inline SymBool sym_ne(const SymInt& a, const SymInt& b) { + return a.sym_ne(b); +} + +inline bool sym_lt(int64_t a, int64_t b) { + return a < b; +} + +inline SymBool sym_lt(const SymInt& a, const SymInt& b) { + return a.sym_lt(b); +} + +inline bool sym_le(int64_t a, int64_t b) { + return a <= b; +} + +inline SymBool sym_le(const SymInt& a, const SymInt& b) { + return a.sym_le(b); +} + +inline bool sym_gt(int64_t a, int64_t b) { + return a > b; +} + +inline SymBool sym_gt(const SymInt& a, const SymInt& b) { + return a.sym_gt(b); +} + +inline bool sym_ge(int64_t a, int64_t b) { + return a >= b; +} + +inline SymBool sym_ge(const SymInt& a, const SymInt& b) { + return a.sym_ge(b); +} + +} // namespace c10 + +#include + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + + static constexpr int64_t max() noexcept { + return std::numeric_limits::max(); + } + + static constexpr int64_t min() noexcept { + return std::numeric_limits::min(); + } + + static constexpr bool is_signed = true; + static constexpr bool is_integer = true; +}; + +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymIntArrayRef.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymIntArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..b63753b186937f0e6869ee557ca1528bb2d7e340 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymIntArrayRef.h @@ -0,0 +1,113 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +using SymIntArrayRef = ArrayRef; + +inline at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) { + return IntArrayRef(reinterpret_cast(ar.data()), ar.size()); +} + +// TODO: a SymIntArrayRef containing a heap allocated large negative integer +// can actually technically be converted to an IntArrayRef... but not with +// the non-owning API we have here. We can't reinterpet cast; we have to +// allocate another buffer and write the integers into it. If you need it, +// we can do it. But I don't think you need it. + +inline std::optional asIntArrayRefSlowOpt( + c10::SymIntArrayRef ar) { + for (const c10::SymInt& sci : ar) { + if (sci.is_heap_allocated()) { + return std::nullopt; + } + } + + return {asIntArrayRefUnchecked(ar)}; +} + +inline at::IntArrayRef asIntArrayRefSlow( + c10::SymIntArrayRef ar, + const char* file, + int64_t line) { + for (const c10::SymInt& sci : ar) { + TORCH_CHECK( + !sci.is_heap_allocated(), + file, + ":", + line, + ": SymIntArrayRef expected to contain only concrete integers"); + } + return asIntArrayRefUnchecked(ar); +} + +// Even slower than asIntArrayRefSlow, as it forces an allocation for a +// destination int, BUT it is able to force specialization (it never errors) +inline c10::DimVector asIntArrayRefSlowAlloc( + c10::SymIntArrayRef ar, + const char* file, + int64_t line) { + c10::DimVector res(ar.size(), 0); + for (const auto i : c10::irange(ar.size())) { + res[i] = ar[i].guard_int(file, line); + } + return res; +} + +#define C10_AS_INTARRAYREF_SLOW(a) c10::asIntArrayRefSlow(a, __FILE__, __LINE__) +#define C10_AS_INTARRAYREF_SLOW_ALLOC(a) \ + c10::asIntArrayRefSlowAlloc(a, __FILE__, __LINE__) + +// Prefer using a more semantic constructor, like +// fromIntArrayRefKnownNonNegative +inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) { + return SymIntArrayRef( + reinterpret_cast(array_ref.data()), array_ref.size()); +} + +inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) { + return fromIntArrayRefUnchecked(array_ref); +} + +inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) { + for (long i : array_ref) { + TORCH_CHECK( + SymInt::check_range(i), + "IntArrayRef contains an int that cannot be represented as a SymInt: ", + i); + } + return SymIntArrayRef( + reinterpret_cast(array_ref.data()), array_ref.size()); +} + +inline c10::SymBool sym_equals(SymIntArrayRef LHS, SymIntArrayRef RHS) { + if (LHS.size() != RHS.size()) { + return c10::SymBool(false); + } + + c10::SymBool result = sym_eq(LHS.size(), RHS.size()); + for (size_t i = 0; i < RHS.size(); ++i) { + c10::SymBool equals = sym_eq(LHS[i], RHS[i]); + std::optional equals_bool = equals.maybe_as_bool(); + + if (equals_bool.has_value() && !*equals_bool) { + // Early return if element comparison is known to be false + return equals; + } + result = result.sym_and(equals); + } + return result; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymNodeImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymNodeImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..a4257684ea150ac4f8f1bda39ab4c1212c1929ed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymNodeImpl.h @@ -0,0 +1,261 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") + +namespace c10 { + +class SymNodeImpl; +using SymNode = c10::intrusive_ptr; + +// When you add a method, you also need to edit +// torch/csrc/jit/python/init.cpp +// torch/csrc/utils/python_symnode.h +// c10/core/ConstantSymNodeImpl.h +class C10_API SymNodeImpl : public c10::intrusive_ptr_target { + public: + ~SymNodeImpl() override = default; + + template + c10::intrusive_ptr dyn_cast() const { + return c10::intrusive_ptr::reclaim_copy(dynamic_cast(this)); + } + + // these could be pure virtual when we implement LTC versions + virtual bool is_int() { + TORCH_CHECK(false, "NYI"); + } + virtual bool is_bool() { + TORCH_CHECK(false, "NYI"); + } + virtual bool is_float() { + TORCH_CHECK(false, "NYI"); + } + virtual bool is_nested_int() const { + return false; + } + virtual SymNode add(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sub(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode mul(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + // NB: legacy, prefer float_truediv or int_truediv + virtual SymNode truediv(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode float_truediv(const SymNode& other) { + return truediv(other); + } + virtual SymNode int_truediv(const SymNode& other) { + return truediv(other); + } + // NB: legacy, prefer float_pow or pow_by_natural + virtual SymNode pow(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode float_pow(const SymNode& other) { + return pow(other); + } + virtual SymNode pow_by_natural(const SymNode& other) { + return pow(other); + } + // NB: legacy, prefer int_floordiv + virtual SymNode floordiv(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode int_floordiv(const SymNode& other) { + return floordiv(other); + } + virtual SymNode mod(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode eq(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode ne(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode gt(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode lt(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode le(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode ge(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode ceil() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode floor() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode neg() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sym_min(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sym_max(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sym_or(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sym_and(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sym_not() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) { + TORCH_CHECK(false, "NYI"); + } + // NB: self is ignored here, only the arguments are used + virtual SymNode is_contiguous( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode is_channels_last_contiguous_2d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode is_channels_last_contiguous_3d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode is_channels_last_strides_2d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode is_channels_last_strides_3d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode is_non_overlapping_and_dense( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode clone() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode sym_float() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode wrap_int(int64_t num) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode wrap_float(double num) { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode wrap_bool(bool num) { + TORCH_CHECK(false, "NYI"); + } + virtual int64_t guard_int(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + } + virtual bool guard_bool(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + } + virtual double guard_float(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + } + virtual bool guard_size_oblivious(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! + return guard_bool(file, line); + } + virtual bool guard_or_false(const char* file, int64_t line) { + // Note: PT2 primarily uses PythonSymNodeImpl for this functionality. + // XLA is currently the main consumer of this fallback path since it uses + // ahead-of-time compilation and cannot depend on Python runtime. + return guard_bool(file, line); + } + virtual bool statically_known_true(const char* file, int64_t line) { + // Note: PT2 primarily uses PythonSymNodeImpl for this functionality. + // XLA is currently the main consumer of this fallback path since it uses + // ahead-of-time compilation and cannot depend on Python runtime. + return guard_bool(file, line); + } + virtual bool guard_or_true(const char* file, int64_t line) { + // Note: PT2 primarily uses PythonSymNodeImpl for this functionality. + // XLA is currently the main consumer of this fallback path since it uses + // ahead-of-time compilation and cannot depend on Python runtime. + return guard_bool(file, line); + } + virtual bool expect_true(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! + return guard_bool(file, line); + } + virtual int64_t int_() { + TORCH_CHECK(false, "NYI"); + } + virtual bool bool_() { + TORCH_CHECK(false, "NYI"); + } + virtual bool has_hint() { + TORCH_CHECK(false, "NYI"); + } + virtual std::string str() { + TORCH_CHECK(false, "NYI"); + } + virtual std::string _graph_repr() { + return str(); + } + virtual std::optional nested_int() { + return std::nullopt; + } + virtual std::optional nested_int_coeff() { + return std::nullopt; + } + virtual std::optional constant_int() { + return std::nullopt; + } + virtual std::optional constant_bool() { + return std::nullopt; + } + virtual std::optional maybe_as_int() { + return std::nullopt; + } + virtual bool is_constant() { + return false; + } + virtual bool is_symbolic() { + return true; + } + std::ostream& operator<<(std::ostream& os) { + os << str(); + return os; + } +}; + +} // namespace c10 +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymbolicShapeMeta.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymbolicShapeMeta.h new file mode 100644 index 0000000000000000000000000000000000000000..411c81a98bac68a34c7c2bafbf78b096bf2bc9cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/SymbolicShapeMeta.h @@ -0,0 +1,234 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +class C10_API SymbolicShapeMeta { + public: + // Basic metadata from which other quantities are derived + SymDimVector sizes_ = {0}; + SymDimVector strides_ = {1}; + SymInt storage_offset_ = 0; + + bool strides_valid_ = true; // e.g. for sparse where there are no strides + + SymbolicShapeMeta() = default; + ~SymbolicShapeMeta() = default; + SymbolicShapeMeta(const SymbolicShapeMeta& other); + SymbolicShapeMeta(SymbolicShapeMeta&& other) = delete; + SymbolicShapeMeta& operator=(const SymbolicShapeMeta& other) = delete; + SymbolicShapeMeta& operator=(SymbolicShapeMeta&& other) = delete; + + void refresh_numel() { + // Non-const, don't need to hold mutables_ lock + available_.fetch_and(~numel_avail); + numel_ = 1; + } + + void refresh_contiguous() { + // Non-const, don't need to hold mutables_ lock + available_.fetch_and(numel_avail); + is_contiguous_ = false; + is_channels_last_contiguous_ = false; + is_channels_last_3d_contiguous_ = false; + is_channels_last_ = false; + is_channels_last_3d_ = false; + is_non_overlapping_and_dense_ = false; + } + + int64_t dim() const { + return static_cast(sizes_.size()); + } + + // Accessors for derived quantities, computed lazily on first access + + bool has_numel() const { + return available_.load() & numel_avail; + } + bool has_is_contiguous() const { + return available_.load() & is_contiguous_avail; + } + bool has_is_channels_last_contiguous() const { + return available_.load() & is_channels_last_contiguous_avail; + } + bool has_is_channels_last_3d_contiguous() const { + return available_.load() & is_channels_last_3d_contiguous_avail; + } + bool has_is_channels_last() const { + return available_.load() & is_channels_last_avail; + } + bool has_is_channels_last_3d() const { + return available_.load() & is_channels_last_3d_avail; + } + bool has_is_non_overlapping_and_dense() const { + return available_.load() & is_non_overlapping_and_dense_avail; + } + + // Accessors to cached derived properties + // DO NOT call with mutables_ lock held + const SymInt& numel() const { + if (C10_UNLIKELY(!has_numel())) { + init_numel(); + } + return numel_; + } + + const SymBool& is_contiguous(at::MemoryFormat memory_format) const { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return this->is_channels_last_contiguous(); + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return this->is_channels_last_3d_contiguous(); + } + return this->is_contiguous(); + } + + const SymBool& is_contiguous() const { + if (C10_UNLIKELY(!has_is_contiguous())) { + init_is_contiguous(); + } + return is_contiguous_; + } + + const SymBool& is_channels_last_contiguous() const { + if (C10_UNLIKELY(!has_is_channels_last_contiguous())) { + init_is_channels_last_contiguous(); + } + return is_channels_last_contiguous_; + } + + const SymBool& is_channels_last_3d_contiguous() const { + if (C10_UNLIKELY(!has_is_channels_last_3d_contiguous())) { + init_is_channels_last_3d_contiguous(); + } + return is_channels_last_3d_contiguous_; + } + + const SymBool& is_channels_last() const { + if (C10_UNLIKELY(!has_is_channels_last())) { + init_is_channels_last(); + } + return is_channels_last_; + } + + const SymBool& is_channels_last_3d() const { + if (C10_UNLIKELY(!has_is_channels_last_3d())) { + init_is_channels_last_3d(); + } + return is_channels_last_3d_; + } + + const SymBool& is_non_overlapping_and_dense() const { + if (C10_UNLIKELY(!has_is_non_overlapping_and_dense())) { + init_is_non_overlapping_and_dense(); + } + return is_non_overlapping_and_dense_; + } + + // Assumptions so we can short-circuit computation + // NOTE: Don't need to lock mutables_ since these aren't const + void assume_contiguous(SymBool val = true) { + is_contiguous_ = std::move(val); + available_.fetch_or(is_contiguous_avail); + } + void assume_channels_last_contiguous(SymBool val = true) { + is_contiguous_ = std::move(val); + available_.fetch_or(is_channels_last_contiguous_avail); + } + void assume_channels_last_3d_contiguous(SymBool val = true) { + is_channels_last_3d_contiguous_ = std::move(val); + available_.fetch_or(is_channels_last_3d_contiguous_avail); + } + void assume_channels_last(SymBool val = true) { + is_channels_last_ = std::move(val); + available_.fetch_or(is_channels_last_avail); + } + void assume_channels_last_3d(SymBool val = true) { + is_channels_last_3d_ = std::move(val); + available_.fetch_or(is_channels_last_3d_avail); + } + void assume_non_overlapping_and_dense(SymBool val = true) { + is_non_overlapping_and_dense_ = std::move(val); + available_.fetch_or(is_non_overlapping_and_dense_avail); + } + + private: + SymBool compute_contiguous() const; + SymBool compute_channels_last_contiguous_2d() const; + SymBool compute_channels_last_contiguous_3d() const; + SymBool compute_strides_like_channels_last_2d() const; + SymBool compute_strides_like_channels_last_3d() const; + SymBool compute_non_overlapping_and_dense() const; + + // These are little wrappers over the real compute_ functions that + // can make use of other contiguity fields to short circuit. + // They need to be implemented separately for SymBool, as SymBool does + // not short circuit. + // TODO: should the SymBool cases avoid the short circuit? Need to reason + // if its correct, and reason if the simpler expressions are better for + // analysis (maybe not!) + + SymBool compute_channels_last_contiguous_3d_dim5() const; + SymBool compute_channels_last_2d_dim5() const; + SymBool compute_channels_last_3d_dim5() const; + SymBool compute_is_non_overlapping_and_dense_dim4() const; + SymBool compute_is_non_overlapping_and_dense_dim5() const; + SymBool compute_is_non_overlapping_and_dense_anydim() const; + + void init_numel() const; + void init_is_contiguous() const; + void init_is_channels_last_contiguous() const; + void init_is_channels_last_3d_contiguous() const; + void init_is_channels_last() const; + void init_is_channels_last_3d() const; + void init_is_non_overlapping_and_dense() const; + + // NOTE: These only set if !has_foo() + void set_numel(SymInt val) const; + void set_is_contiguous(SymBool val) const; + void set_is_channels_last_contiguous(SymBool val) const; + void set_is_channels_last_3d_contiguous(SymBool val) const; + void set_is_channels_last(SymBool val) const; + void set_is_channels_last_3d(SymBool val) const; + void set_is_non_overlapping_and_dense(SymBool val) const; + + // Lazily initialized variables, with the corresponding available_ flag + // indicating whether the value has been initialized + mutable std::atomic available_{0}; + + enum avail { + numel_avail = 1 << 0, + is_contiguous_avail = 1 << 1, + is_channels_last_contiguous_avail = 1 << 2, + is_channels_last_3d_contiguous_avail = 1 << 3, + is_channels_last_avail = 1 << 4, + is_channels_last_3d_avail = 1 << 5, + is_non_overlapping_and_dense_avail = 1 << 6, + }; + + // Mutex to prevent races when initializing the variable from const accessors + mutable std::mutex mutables_; + mutable SymInt numel_ = 1; + mutable SymBool is_contiguous_{true}; + mutable SymBool is_channels_last_contiguous_{false}; + mutable SymBool is_channels_last_3d_contiguous_{false}; + mutable SymBool is_channels_last_{false}; + mutable SymBool is_channels_last_3d_{false}; + mutable SymBool is_non_overlapping_and_dense_{true}; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/TensorImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/TensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..03faea3fbc70500bda37a8099657e80f38976657 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/TensorImpl.h @@ -0,0 +1,3333 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// A global boolean variable to control whether we free memory when a Tensor +// is shrunk to a smaller size. As a result, a Tensor is always going to +// keep the memory allocated for its maximum capacity reshaped to so far. +// +// This parameter is respected "upper-case" methods which call Resize() +// (e.g., CopyFrom, ResizeLike); it is NOT respected by Tensor::resize_ +// or ShrinkTo, both of which guarantee to never to free memory. +C10_DECLARE_bool(caffe2_keep_on_shrink); + +// Since we can have high variance in blob memory allocated across different +// inputs in the same run, we will shrink the blob only if the memory gain +// is larger than this flag in bytes. This only applies to functions which +// respect caffe2_keep_on_shrink. +C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory); + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-default") + +namespace at { +class Tensor; +class TensorBase; +} // namespace at + +namespace c10 { + +/** + * A utility function to convert vector to vector. + */ +inline std::vector ToVectorint64_t(const ArrayRef& src) { + return std::vector(src.begin(), src.end()); +} + +/** + * Return product of all dimensions starting from k + */ +inline int64_t size_from_dim_(int k, IntArrayRef dims) { + int64_t r = 1; + for (const auto i : c10::irange(k, dims.size())) { + r *= dims[i]; + } + return r; +} + +// Product of all dims up to k (not including dims[k]) +inline int64_t size_to_dim_(int k, IntArrayRef dims) { + TORCH_CHECK(k >= 0 && static_cast(k) <= dims.size()); + int64_t r = 1; + for (const auto i : c10::irange(k)) { + r *= dims[i]; + } + return r; +} + +// Product of all dims between k and l (not including dims[k] and dims[l]) +inline int64_t size_between_dim_(int k, int l, IntArrayRef dims) { + TORCH_CHECK((unsigned)l < dims.size() && (unsigned)k < dims.size()); + int64_t r = 1; + if (k < l) { + for (int i = k + 1; i < l; ++i) { + r *= dims[i]; + } + } else { + for (int i = l + 1; i < k; ++i) { + r *= dims[i]; + } + } + return r; +} + +// Wrap around axis_index if it is negative, s.t., -1 is the last dim +inline int canonical_axis_index_(int axis_index, int ndims) { + TORCH_CHECK(axis_index >= -ndims); + TORCH_CHECK(axis_index < ndims); + if (axis_index < 0) { + return axis_index + ndims; + } + return axis_index; +} + +using PlacementDtor = void (*)(void*, size_t); + +/* + * A Context that will call extra placement deleter during + * deconstruction. + * + * Accept a already constructed DataPtr and store it as member + * during destruction, we'll call extra deleter on the underlying + * data pointer before the DataPtr is destructed. + * `data_ptr_` owns the memory. + */ +struct C10_API PlacementDeleteContext { + DataPtr data_ptr_; + PlacementDtor placement_dtor_; + size_t size_; + + PlacementDeleteContext( + DataPtr&& data_ptr, + PlacementDtor placement_dtor, + size_t size) + : data_ptr_(std::move(data_ptr)), + placement_dtor_(placement_dtor), + size_(size) {} + + PlacementDeleteContext(PlacementDeleteContext&&) noexcept = delete; + PlacementDeleteContext(const PlacementDeleteContext&) = delete; + PlacementDeleteContext& operator=(const PlacementDeleteContext&) = delete; + PlacementDeleteContext& operator=(PlacementDeleteContext&&) = delete; + static DataPtr makeDataPtr( + DataPtr&& data_ptr, + PlacementDtor placement_dtor, + size_t size, + Device device); + ~PlacementDeleteContext() { + placement_dtor_(data_ptr_.get(), size_); + // original memory will be freed when data_ptr_ is destructed + } +}; + +struct C10_API AutogradMetaInterface { + virtual void set_requires_grad( + bool requires_grad, + at::TensorImpl* self_impl) = 0; + virtual bool requires_grad() const = 0; + virtual at::Tensor& mutable_grad() = 0; + virtual const at::Tensor& grad() const = 0; + virtual const at::Tensor& fw_grad(uint64_t level, const at::TensorBase& self) + const = 0; + virtual void set_fw_grad( + const at::TensorBase& new_grad, + const at::TensorBase& self, + uint64_t level, + bool is_inplace_op) = 0; + virtual ~AutogradMetaInterface(); +}; + +namespace impl { + +// Unfortunately, the definition of AutogradMeta lives in a separate +// compilation unit than TensorImpl (libtorch.so versus libc10.so) +// which means that we cannot construct an AutogradMeta from TensorImpl, +// not even from the cpp file. So we have to indirect it through a factory +// function which will be initialized when we load libtorch.so. + +struct C10_API AutogradMetaFactory { + virtual ~AutogradMetaFactory() = default; + virtual std::unique_ptr make() const = 0; + // This method is the dumbest method. But I don't have access + // to Tensor (not TensorImpl) which is undefined in this header. + virtual const at::Tensor& undefined_tensor() const = 0; +}; + +C10_API void SetAutogradMetaFactory(AutogradMetaFactory* factory); +C10_API AutogradMetaFactory* GetAutogradMetaFactory(); + +struct C10_API AutogradMetaFactoryRegisterer{ + explicit AutogradMetaFactoryRegisterer(AutogradMetaFactory * factory){ + SetAutogradMetaFactory(factory); +} // namespace impl +}; // namespace c10 + +} // namespace impl + +struct C10_API NamedTensorMetaInterface { + virtual ~NamedTensorMetaInterface() = default; + virtual std::unique_ptr clone() const { + TORCH_INTERNAL_ASSERT( + false, "Not implemented: NamedTensorMetaInterface::clone"); + } + virtual int64_t slow_dim() const { + TORCH_INTERNAL_ASSERT( + false, "Not implemented: NamedTensorMetaInterface::slow_dim"); + } +}; + +// For ease of copy pasting +#if 0 +is_contiguous +is_channels_last_contiguous +is_channels_last_3d_contiguous +is_channels_last +is_channels_last_3d +is_non_overlapping_and_dense +#endif + +/** + * This structure is intended to hold additional metadata of the specific device + * backend. + **/ +struct C10_API BackendMeta : intrusive_ptr_target { + ~BackendMeta() override = default; + virtual intrusive_ptr clone( + const intrusive_ptr& ptr) const { + return ptr; + } +}; + +struct C10_API ExtraMeta { + std::unique_ptr symbolic_shape_meta_ = nullptr; + std::unique_ptr named_tensor_meta_ = nullptr; + intrusive_ptr backend_meta_ = nullptr; + std::optional custom_data_ptr_error_msg_ = std::nullopt; + std::optional custom_storage_error_msg_ = std::nullopt; + + ExtraMeta() = default; + ~ExtraMeta() = default; + ExtraMeta(const ExtraMeta& other) { + if (other.symbolic_shape_meta_) { + symbolic_shape_meta_ = + std::make_unique(*other.symbolic_shape_meta_); + } + if (other.named_tensor_meta_) { + named_tensor_meta_ = other.named_tensor_meta_->clone(); + } + if (other.backend_meta_) { + backend_meta_ = other.backend_meta_->clone(other.backend_meta_); + } + if (other.custom_data_ptr_error_msg_) { + custom_data_ptr_error_msg_ = other.custom_data_ptr_error_msg_; + } + if (other.custom_storage_error_msg_) { + custom_storage_error_msg_ = other.custom_storage_error_msg_; + } + } + ExtraMeta& operator=(const ExtraMeta& other) = delete; + ExtraMeta(ExtraMeta&& other) = delete; + ExtraMeta& operator=(ExtraMeta&& other) = delete; + + ExtraMeta( + std::unique_ptr symbolic_shape_meta, + std::unique_ptr named_tensor_meta, + intrusive_ptr backend_meta, + std::optional custom_data_ptr_error_msg = std::nullopt, + std::optional custom_storage_access_error_msg = std::nullopt) + : symbolic_shape_meta_(std::move(symbolic_shape_meta)), + named_tensor_meta_(std::move(named_tensor_meta)), + backend_meta_(std::move(backend_meta)), + custom_data_ptr_error_msg_(std::move(custom_data_ptr_error_msg)), + custom_storage_error_msg_(std::move(custom_storage_access_error_msg)) {} + + std::unique_ptr clone() const { + return std::make_unique(*this); + } +}; + +// NOTE [ Version Counter Sharing ] +// +// Every Tensor has a version counter. Version counters are incremented whenever +// the data or size of a tensor changes through in-place Variable operations. +// Version counters are used to detect modifications to saved variables which +// would result in incorrect gradient calculations. Version counters may be +// shared between Variables: +// +// 1. A view shares the version counter of the base Variable, +// 2. `x.detach()` shares the version counter of `x`, +// 3. Unpacked saved variables share the version counter of the source. +// +// Version counters are not shared in these scenarios: +// +// 1. When we replace a `Variable`'s underlying `Tensor` by calling +// `set_data(...)`, +// 2. `x.data` does not share the version counter of `x`. (See discussion at +// https://github.com/pytorch/pytorch/issues/5396) +// +// Question: Why do we put the version counter in TensorImpl instead of +// AutogradMeta? +// +// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta +// when its `requires_grad_` is false, but when we use this tensor in the +// forward pass of a function that requires saving this tensor for backward, we +// need to keep track of this tensor's version to make sure it's always valid in +// the autograd graph. +// +// To achieve this goal, we put the version counter in TensorImpl instead of +// AutogradMeta, and have it always be available. This allows us to have the +// optimization of not carrying AutogradMeta when a tensor doesn't require +// gradient. +// +// A hypothetical alternative way to achieve this goal is to initialize +// AutogradMeta and create the version counter for the non-requires-grad tensor +// only when it's saved for backward. However, since saving a tensor for +// backward happens in the forward pass, and our invariant is that forward pass +// needs to be thread-safe, lazy-initializing AutogradMeta when saving a tensor +// can introduce race conditions when we are running the forward pass in +// multi-thread scenarios, thus making the forward pass not thread-safe anymore, +// which breaks the invariant. +struct C10_API VariableVersion { + private: + struct VersionCounter : intrusive_ptr_target { + VersionCounter(uint32_t version) : version_(version) {} + std::atomic version_; + }; + c10::intrusive_ptr version_counter_; + + public: + // Note [Disabled VariableVersion] + // VariableVersion struct has an intrusive_ptr pointing VersionCounter struct + // with an atomic variable. Thus `VariableVersion(/*version=*/0)` is not as + // cheap as we expected. In some cases constructing a VariableVersion with + // version 0 is not necessary so we add a cheap constructor which + // doesn't allocate the intrusive_ptr. + // Example use cases are: + // - Inference tensors don't track version counter, so they'll just always + // have disabled VariableVersion. + // - In SavedVariable class we override version_counter_ inside its + // constructor + // so that we can use the cheap constructor there. + enum Disabled { DISABLED }; + // It's okay to return true even for inference tensor which + // doesn't have version counter enabled. + // We want to be permissive here since in many cases (e.g. make_variable) + // we can std::move a TensorImpl if there's no other uses which saves us + // an additional TensorImpl allocation. + bool unique() const { + return version_counter_ ? 1 == version_counter_.use_count() : true; + } + // NOTE: As of C++11 and 14, default-constructing a std::atomic variable + // leaves it in a persistently undefined state. See + // https://cplusplus.github.io/LWG/issue2334. + VariableVersion(uint32_t version) + : version_counter_(c10::make_intrusive(version)) {} + VariableVersion(Disabled /*unused*/ = DISABLED) {} + + bool enabled() const { + return version_counter_; + } + + // Note [Inplace update inference tensor] + // 1. Inplace update to inference tensor is forbidden in normal mode. + // For example: + // inference_tensor.copy_(normal_tensor_requires_grad) + // This inplace makes inference_tensor have requires_grad=True and + // have a grad_fn. This is bad because views of `inference_tensor` + // created in InferenceMode won't be able to know the grad_fn since + // their ViewMeta were not recorded. To match NoGradMode behavior + // that "inplace update to a view created in NoGradMode raise an error", + // we just ban inplace update to inference tensor since we can't tell + // if an inference tensor is a view created in InferenceMode. + // + // Note that views of normal tensor created in InferenceMode has proper + // ViewMeta so that they're aware of the grad_fn correctly. + // + // 2. Inplace update to inference tensor in inference tensor doesn't bump + // version counter. + // * It either doesn't call bump() by skipping ADInplaceOrView kernel, + // - e.g. inference_tensor.add_(1) + // * or bump() is a no-op for inference tensor. + // - e.g. inference_tensor.add_(normal_tensor) + void bump() { + // TODO: Replace the link to the documentation once it's available. + TORCH_CHECK( + version_counter_ || InferenceMode::is_enabled(), + "Inplace update to inference tensor outside InferenceMode is not allowed." + "You can make a clone to get a normal tensor before doing inplace update." + "See https://github.com/pytorch/rfcs/pull/17 for more details."); + if (version_counter_) { + ++version_counter_->version_; + } + } + + void set_version(int64_t i) { + TORCH_CHECK( + version_counter_, + "Tried to call torch.autograd._unsafe_set_version() on a tensor " + "that does not have a version counter. Was it created in inference mode?"); + TORCH_CHECK(i >= 0, "Cannot set a version_counter to a value below 0: ", i); + version_counter_->version_ = i; + } + + // Inference tensor doesn't have version counter so it shouldn't be + // accessed. + uint32_t current_version() const { + TORCH_CHECK( + version_counter_, "Inference tensors do not track version counter."); + return version_counter_->version_; + } +}; + +// Forward declaration of TensorImpl needed for forward declaration of +// C10_TensorImpl_Size_Check_Dummy_Class +struct C10_API TensorImpl; + +/** + * NOTE: Some TensorImpl methods are small and not overridden in the + * PyTorch codebase itself, but may theoretically need to be + * overridden by third-party TensorImpl subclasses. This macro allows + * users that need maximum performance and don't need these extension + * points to disable them with a build-time flag. (In particular, + * XLA's XLATensorImpl currently overrides these methods, so we can't + * enable this flag by default.) + */ +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY +#define TENSORIMPL_MAYBE_VIRTUAL +#else +#define TENSORIMPL_MAYBE_VIRTUAL virtual +#endif + +/** + * The low-level representation of a tensor, which contains a pointer + * to a storage (which contains the actual data) and metadata (e.g., sizes and + * strides) describing this particular view of the data as a tensor. + * + * Some basic characteristics about our in-memory representation of + * tensors: + * + * - It contains a pointer to a storage struct (Storage/StorageImpl) + * which contains the pointer to the actual data and records the + * data type and device of the view. This allows multiple tensors + * to alias the same underlying data, which allows to efficiently + * implement differing *views* on a tensor. + * + * - The tensor struct itself records view-specific metadata about + * the tensor, e.g., sizes, strides and offset into storage. + * Each view of a storage can have a different size or offset. + * + * - This class is intrusively refcounted. It is refcounted so that + * we can support prompt deallocation of large tensors; it is + * intrusively refcounted so that we can still perform reference + * counted operations on raw pointers, which is often more convenient + * when passing tensors across language boundaries. + * + * - For backwards-compatibility reasons, a tensor may be in an + * uninitialized state. A tensor may be uninitialized in the following + * two ways: + * + * - A tensor may be DTYPE UNINITIALIZED. A tensor of this + * form has an uninitialized dtype. This situation most + * frequently arises when a user writes Tensor x(CPU). The dtype + * is subsequently initialized when mutable_data() is + * invoked for the first time. + * + * - A tensor may be STORAGE UNINITIALIZED. A tensor of this form + * has non-zero size, but has a storage with a null data pointer. + * This situation most frequently arises when a user calls + * Resize() or FreeMemory(). This is because Caffe2 historically + * does lazy allocation: allocation of data doesn't occur until + * mutable_data() is invoked. A tensor with zero size is + * always storage initialized, because no allocation is necessary + * in this case. + * + * All combinations of these two uninitialized states are possible. + * Consider the following transcript in idiomatic Caffe2 API: + * + * Tensor x(CPU); // x is storage-initialized, dtype-UNINITIALIZED + * x.Resize(4); // x is storage-UNINITIALIZED, dtype-UNINITIALIZED + * x.mutable_data(); // x is storage-initialized, dtype-initialized + * x.FreeMemory(); // x is storage-UNINITIALIZED, dtype-initialized. + * + * All other fields on tensor are always initialized. In particular, + * size is always valid. (Historically, a tensor declared as Tensor x(CPU) + * also had uninitialized size, encoded as numel == -1, but we have now + * decided to default to zero size, resulting in numel == 0). + * + * Uninitialized storages MUST be uniquely owned, to keep our model + * simple. Thus, we will reject operations which could cause an + * uninitialized storage to become shared (or a shared storage to + * become uninitialized, e.g., from FreeMemory). + * + * In practice, tensors which are storage-UNINITIALIZED and + * dtype-UNINITIALIZED are *extremely* ephemeral: essentially, + * after you do a Resize(), you basically always call mutable_data() + * immediately afterwards. Most functions are not designed to + * work if given a storage-UNINITIALIZED, dtype-UNINITIALIZED tensor. + * + * We intend to eliminate all uninitialized states, so that every + * tensor is fully initialized in all fields. Please do not write new code + * that depends on these uninitialized states. + */ +struct C10_API TensorImpl : public c10::intrusive_ptr_target { + TensorImpl() = delete; + ~TensorImpl() override; + // Note [Enum ImplType] + // This enum is temporary. In the followup refactor we should + // think about how to specialize TensorImpl creation for view + // tensors. Currently we only special case its key_set_ but + // there's also potential to share version_counter_ directly + // without creating first and then override in as_view. + enum ImplType { VIEW }; + + /** + * Construct a 1-dim 0-size tensor backed by the given storage. + */ + TensorImpl( + Storage&& storage, + DispatchKeySet /*key_set*/, + const caffe2::TypeMeta data_type); + + // See Note [Enum ImplType] + TensorImpl( + ImplType /*unused*/, + Storage&& storage, + DispatchKeySet /*key_set*/, + const caffe2::TypeMeta data_type); + + /** + * Construct a 1-dim 0 size tensor that doesn't have a storage. + */ + TensorImpl( + DispatchKeySet /*key_set*/, + const caffe2::TypeMeta data_type, + std::optional device_opt); + + // Legacy constructors so I don't have to go update call sites. + // TODO: When Variable is added, delete these constructors + TensorImpl( + Storage&& storage, + DispatchKey dispatch_key, + const caffe2::TypeMeta data_type) + : TensorImpl( + std::move(storage), + DispatchKeySet(dispatch_key), + data_type) {} + TensorImpl( + DispatchKey dispatch_key, + const caffe2::TypeMeta data_type, + std::optional device_opt) + : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} + + private: + // This constructor is private, because the data_type is redundant with + // storage. Still, we pass it in separately because it's easier to write + // the initializer list if we're not worried about storage being moved out + // from under us. + TensorImpl( + Storage&& storage, + DispatchKeySet /*key_set*/, + const caffe2::TypeMeta data_type, + std::optional /*device_opt*/); + + public: + TensorImpl(const TensorImpl&) = delete; + TensorImpl& operator=(const TensorImpl&) = delete; + TensorImpl(TensorImpl&&) = delete; + TensorImpl& operator=(TensorImpl&&) = delete; + + /** + * Release (decref) storage, and any other external allocations. This + * override is for `intrusive_ptr_target` and is used to implement weak + * tensors. + */ + void release_resources() override; + + public: + /** + * Return the DispatchKeySet corresponding to this Tensor, specifying + * all of the DispatchKeys that this Tensor identifies as. This is the + * information used to dispatch operations on this tensor. + */ + DispatchKeySet key_set() const { + return key_set_; + } + + private: + [[noreturn]] void throw_cannot_call_with_symbolic(const char* meth) const; + + // NOTE: The general recipe for customizable methods is that the fastpath + // function (e.g., sizes()) does an unlikely policy test, and if doesn't + // trigger, it does the fast path implementation with no checks and going + // directly to on-TensorImpl fields. In particular, you never need to + // check ExtraMeta if the policy doesn't trigger, as non-trivial ExtraMeta + // implies the policy will always match. + // + // The default implementations of methods are "safe": they do extra tests + // to make sure the internal state is consistent no matter if you are + // doing symbolic shapes or not. If you don't want the tests, directly + // override the custom method (e.g., custom_sizes()) to do your preferred + // behavior. + + public: + /** + * Return a reference to the sizes of this tensor. This reference remains + * valid as long as the tensor is live and not resized. + */ + IntArrayRef sizes() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sizes_custom(); + } + return sizes_and_strides_.sizes_arrayref(); + } + + SymIntArrayRef sym_sizes() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_sizes_custom(); + } + // Sizes guaranteed to be non-negative, so unchecked cast is OK + return c10::fromIntArrayRefKnownNonNegative( + sizes_and_strides_.sizes_arrayref()); + } + + IntArrayRef sizes_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("sizes"); + } + return sizes_and_strides_.sizes_arrayref(); + } + + SymIntArrayRef sym_sizes_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().sizes_; + } else { + // Sizes guaranteed to be non-negative, so unchecked cast is OK + return c10::fromIntArrayRefKnownNonNegative(sizes_default()); + } + } + + template + ArrayRef generic_sizes() { + static_assert( + std::is_same_v || std::is_same_v, + "Only supports int64_t and c10::SymInt."); + + if constexpr (std::is_same_v) { + return sizes(); + } else { + return sym_sizes(); + } + } + + template + ArrayRef generic_strides() { + static_assert( + std::is_same_v || std::is_same_v, + "Only supports int64_t and c10::SymInt."); + + if constexpr (std::is_same_v) { + return strides(); + } else { + return sym_strides(); + } + } + + template + T generic_storage_offset() { + static_assert( + std::is_same_v || std::is_same_v, + "Only supports int64_t and c10::SymInt."); + + if constexpr (std::is_same_v) { + return storage_offset(); + } else { + return sym_storage_offset(); + } + } + + /** + * The number of elements in a tensor. + * + * WARNING: Previously, if you were using the Caffe2 API, you could + * test numel() == -1 to see if a tensor was uninitialized. This + * is no longer true; numel always accurately reports the product + * of sizes of a tensor. + */ + int64_t numel() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return numel_custom(); + } + return numel_; + } + + c10::SymInt sym_numel() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_numel_custom(); + } + return c10::SymInt(SymInt::UNCHECKED, numel_); + } + + int64_t numel_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("numel"); + } + return numel_; + } + + c10::SymInt sym_numel_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().numel(); + } else { + return c10::SymInt(SymInt::UNCHECKED, numel_); + } + } + + /** + * Return the number of dimensions of this tensor. Note that 0-dimension + * represents a Tensor that is a Scalar, e.g., one that has a single element. + */ + int64_t dim() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return dim_custom(); + } + return static_cast(sizes_and_strides_.size()); + } + + int64_t dim_default() const { + if (has_symbolic_sizes_strides_) { + return static_cast(symbolic_shape_meta().sizes_.size()); + } else { + return static_cast(sizes_and_strides_.size()); + } + } + + /** + * Return the offset in number of elements into the storage that this + * tensor points to. Most tensors have storage_offset() == 0, but, + * for example, an index into a tensor will have a non-zero storage_offset(). + * + * WARNING: This is NOT computed in bytes. + */ + int64_t storage_offset() const { + // TODO: maybe this should be toggled by strides + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return storage_offset_custom(); + } + return storage_offset_; + } + + c10::SymInt sym_storage_offset() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_storage_offset_custom(); + } + return c10::SymInt(SymInt::UNCHECKED, storage_offset_); + } + + int64_t storage_offset_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("storage_offset"); + } + return storage_offset_; + } + + c10::SymInt sym_storage_offset_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().storage_offset_; + } else { + return c10::SymInt(SymInt::UNCHECKED, storage_offset_); + } + } + + /** + * Return a reference to the strides of this tensor. This reference remains + * valid as long as the tensor is live and not restrided. + */ + IntArrayRef strides() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return strides_custom(); + } + return sizes_and_strides_.strides_arrayref(); + } + + c10::SymIntArrayRef sym_strides() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_strides_custom(); + } + return c10::fromIntArrayRefKnownNonNegative(strides_default()); + } + + IntArrayRef strides_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("strides"); + } + return sizes_and_strides_.strides_arrayref(); + } + + c10::SymIntArrayRef sym_strides_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().strides_; + } else { + return c10::fromIntArrayRefKnownNonNegative(strides_default()); + } + } + + c10::SymBool sym_is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_is_contiguous_custom(memory_format); + } + return sym_is_contiguous_default(memory_format); + } + + template + T is_contiguous_default_impl(at::MemoryFormat memory_format) const { + if (!has_symbolic_sizes_strides_) { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return is_channels_last_contiguous_; + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return is_channels_last_3d_contiguous_; + } + return is_contiguous_; + } + + // Handle dynamic shapes. + const auto& symbolic = symbolic_shape_meta().is_contiguous(memory_format); + + if constexpr (std::is_same_v) { + return symbolic.guard_bool(__FILE__, __LINE__); + } else { + return symbolic; + } + } + + bool is_contiguous_default(at::MemoryFormat memory_format) const { + return is_contiguous_default_impl(memory_format); + } + + c10::SymBool sym_is_contiguous_default(at::MemoryFormat memory_format) const { + return is_contiguous_default_impl(memory_format); + } + + /** + * Whether or not a tensor is laid out in contiguous memory. + * + * Tensors with non-trivial strides are not contiguous. See + * compute_contiguous() for the exact definition of whether or not + * a tensor is contiguous or not. + */ + bool is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return is_contiguous_custom(memory_format); + } + return is_contiguous_default(memory_format); + } + + bool is_strides_like_default(at::MemoryFormat memory_format) const { + if (has_symbolic_sizes_strides_) { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return symbolic_shape_meta().is_channels_last().guard_bool( + __FILE__, __LINE__); + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return symbolic_shape_meta().is_channels_last_3d().guard_bool( + __FILE__, __LINE__); + } else { + return false; + } + } + + if (memory_format == at::MemoryFormat::ChannelsLast) { + return is_channels_last_; + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return is_channels_last_3d_; + } else { + return false; + } + } + + SymBool sym_is_non_overlapping_and_dense_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().is_non_overlapping_and_dense(); + } else { + return is_non_overlapping_and_dense_; + } + } + + bool is_non_overlapping_and_dense_default() const { + if (has_symbolic_sizes_strides_) { + return sym_is_non_overlapping_and_dense_default().guard_bool( + __FILE__, __LINE__); + } else { + return is_non_overlapping_and_dense_; + } + } + + // NB: these dim accessor functions don't have _default(), as you can use + // sizes_default/strides_default + /** + * Return the size of a tensor at some dimension, wrapping the dimension if + * necessary. + * + * NOTE: if you know wrapping is unnecessary, do sizes()[d] instead; it will + * be faster + */ + int64_t size(int64_t d) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return size_custom(d); + } + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sizes_and_strides_.size_at_unchecked(d); + } + + c10::SymInt sym_size(int64_t d) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_size_custom(d); + } + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + const auto sizes = this->sym_sizes(); + return sizes[d]; + } + + /** + * Return the stride of a tensor at some dimension, wrapping the dimension + * if necessary. + * + * NOTE: if you know wrapping is unnecessary, do sizes()[d] instead; it will + * be faster + */ + int64_t stride(int64_t d) const { + d = maybe_wrap_dim(d, dim(), false); + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + // TODO: provide stride_custom, symmetrically with size_custom. + // There is presently no user for it; only NestedTensor is using + // size_custom overrideability + return strides_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + // Intentionally don't call default, which also handles symbolic + return sizes_and_strides_.stride_at_unchecked(d); + } + + enum class SizesStridesPolicy : uint8_t { + // Default behavior, e.g., dense tensor. + // + // Can override: nothing + Default = 0, + // Customizable strides behavior, e.g., sparse tensor, + // mkldnn tensor. + // + // Can override: strides(), is_contiguous() + CustomStrides = 1, + // Customizable sizes behavior, e.g., nested tensor + // + // Can override: strides(), is_contiguous(), sizes(), dim(), numel() + CustomSizes = 2 + }; + + protected: + inline bool matches_policy(SizesStridesPolicy policy) const { + return sizes_strides_policy_ >= static_cast(policy); + } + + inline bool matches_custom(SizesStridesPolicy policy) const { + return custom_sizes_strides_ >= static_cast(policy); + } + + inline bool matches_python_custom(SizesStridesPolicy policy) const { + auto r = python_custom_sizes_strides_ >= static_cast(policy); + if (r) { + TORCH_INTERNAL_ASSERT(is_python_dispatch()) + } + return r; + } + + /** + * Customization points for the functions above. sizes_strides_policy_ + * must be set to enable these. + * + * NB: dim is overridable separately from sizes because it is possible + * for a tensor to have rank, but not well defined sizes. + */ + // sizes_strides_policy_ >= CustomStrides + + virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; + + virtual c10::SymBool sym_is_non_overlapping_and_dense_custom() const; + + bool is_non_overlapping_and_dense_custom() const { + return sym_is_non_overlapping_and_dense_custom().guard_bool( + __FILE__, __LINE__); + } + + virtual c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const; + + bool is_contiguous_custom(at::MemoryFormat memory_format) const { + return sym_is_contiguous_custom(memory_format) + .guard_bool(__FILE__, __LINE__); + } + + // sizes_strides_policy_ >= CustomSizes + // Currently this method only exists to be overwritten by subclasses such as + // NestedTensorImpl. + virtual int64_t size_custom(int64_t d) const { + // TODO: We could add support to Python dispatch here. + // TODO: We could call into aten::size.int instead of + // sizes_custom()[d] and enable use of the dispatcher. + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + + virtual c10::SymInt sym_size_custom(int64_t d) const { + // TODO: We could add support to Python dispatch here. + // TODO: We could call into aten::size.int instead of + // sym_sizes_custom()[d] and enable use of the dispatcher. + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + + virtual IntArrayRef sizes_custom() const; + virtual IntArrayRef strides_custom() const; + virtual int64_t numel_custom() const; + virtual int64_t storage_offset_custom() const; + virtual int64_t dim_custom() const; + virtual Device device_custom() const; + virtual Layout layout_custom() const; + + virtual c10::SymIntArrayRef sym_sizes_custom() const; + virtual c10::SymIntArrayRef sym_strides_custom() const; + virtual c10::SymInt sym_numel_custom() const; + virtual c10::SymInt sym_storage_offset_custom() const; + + public: +/** + * True if this tensor has storage. See storage() for details. + */ +#ifdef DEBUG + // Allow subclasses to check that their storage_ is never getting set in debug + // builds. + virtual +#else + TENSORIMPL_MAYBE_VIRTUAL +#endif + bool + has_storage() const +// NOTE: we devirtualize this because it arguably shouldn't be an +// error just to ask subclasses if they have storage. +// This used to throw for most subclasses, but OpaqueTensorImpl +// wanted it to successfully return false, so we went ahead and made +// it a non-error. +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY + { + return storage_; + } +#else + ; +#endif + + /** + * Return the underlying storage of a Tensor. Multiple tensors may share + * a single storage. A Storage is an impoverished, Tensor-like class + * which supports far less operations than Tensor. + * + * Avoid using this method if possible; try to use only Tensor APIs to perform + * operations. + */ + TENSORIMPL_MAYBE_VIRTUAL const Storage& storage() const { + if (C10_UNLIKELY(storage_access_should_throw_)) { + throw_storage_access_error(); + } + return storage_; + } + + /** + * Return the underlying storage, unsafely assuming this is a basic strided + * tensor. In cases where `storage` access would throw, this returns a + * default-constructed Storage. + */ + inline const Storage& unsafe_storage() const { + return storage_; + } + + bool unique_version() const { + return version_counter_.unique(); + } + + protected: + virtual Layout layout_impl() const { + TORCH_CHECK( + false, "layout_impl is only implemented for TensorImpl subclasses."); + } + + public: + // Whether a tensor is sparse COO or not. + bool is_sparse() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + return key_set_.has_all(c10::sparse_ks); + } + + // Whether a tensor is sparse CSR or not. + bool is_sparse_csr() const { + return layout() == kSparseCsr; + } + + // Whether a tensor is sparse CSR/CSC/BSR/BSC or not. + bool is_sparse_compressed() const { + return key_set_.has_all(c10::sparse_csr_ks); + } + + bool is_quantized() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + constexpr auto quantized_ks = DispatchKeySet(DispatchKey::Quantized); + return key_set_.has_all(quantized_ks); + } + + bool is_meta() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_meta(); + } + return device_opt_.has_value() && device_opt_->type() == kMeta; + } + + bool is_cpu() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_cpu(); + } + // Note: we cannot rely on dispatch keys to determine the device type + // of a tensor, because "wrapper" tensors (like FunctionalTensorWrapper) + // don't include backend dispatch keys. + return device_opt_.has_value() && device_opt_->type() == kCPU; + } + + bool is_cuda() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_cuda(); + } + return device_opt_.has_value() && device_opt_->type() == kCUDA; + } + + bool is_xpu() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_xpu(); + } + return device_opt_.has_value() && device_opt_->type() == kXPU; + } + + bool is_ipu() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_ipu(); + } + return device_opt_.has_value() && device_opt_->type() == kIPU; + } + + bool is_xla() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_xla(); + } + return device_opt_.has_value() && device_opt_->type() == kXLA; + } + + bool is_mtia() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_mtia(); + } + return device_opt_.has_value() && device_opt_->type() == kMTIA; + } + + bool is_hpu() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_hpu(); + } + return device_opt_.has_value() && device_opt_->type() == kHPU; + } + + bool is_lazy() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_lazy(); + } + return device_opt_.has_value() && device_opt_->type() == kLazy; + } + + bool is_hip() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_hip(); + } + return device_opt_.has_value() && device_opt_->type() == kHIP; + } + + bool is_ve() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_ve(); + } + return device_opt_.has_value() && device_opt_->type() == kVE; + } + + bool is_privateuseone() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_privateuseone(); + } + return device_opt_.has_value() && device_opt_->type() == kPrivateUse1; + } + + bool is_mkldnn() const { + return key_set_.has_all(c10::mkldnn_ks); + } + + bool is_vulkan() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_vulkan(); + } + return device_opt_.has_value() && device_opt_->type() == kVulkan; + } + + bool is_metal() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_metal(); + } + return device_opt_.has_value() && device_opt_->type() == kMetal; + } + + bool is_mps() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_mps(); + } + return device_opt_.has_value() && device_opt_->type() == kMPS; + } + + bool is_maia() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_maia(); + } + return device_opt_.has_value() && device_opt_->type() == kMAIA; + } + + bool is_nested() const { + return key_set_.has(DispatchKey::NestedTensor); + } + + // TODO: remove this once we don't automatically enabled Autograd dispatch + // keys + // in TensorImpl constructor. + // DON'T USE THIS API!! It's only created for testing purpose in + // file aten/src/ATen/core/boxing/impl/test_helpers.h + void remove_autograd_key() { + key_set_ = key_set_ - autograd_dispatch_keyset; + } + + // Inference tensor doesn't have autograd or ADInplaceOrView key. + // Invariant: + // Inference tensor has version_counter_.enabled() == false + bool is_inference() { + bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks); + bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + no_ADInplaceOrView == no_Autograd, + "ADInplaceOrView and Autograd keys must be on/off at the same time."); + return no_ADInplaceOrView && no_Autograd; + } + + DeviceIndex get_device() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().index(); + } + return device_default().index(); + } + + Device device() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom(); + } + return device_default(); + } + + protected: + c10::Device device_default() const { + TORCH_CHECK(device_opt_.has_value(), "tensor does not have a device"); + // See NOTE [std::optional operator usage in CUDA] + return *device_opt_; + } + + public: + Layout layout() const { + if (C10_UNLIKELY(layout_policy_)) { + return layout_custom(); + } + + // NB: This method is not virtual and avoid dispatches for perf. + // strided is also the most common layout type, so we check for + // strided case first. + // This keyset must also be kept in sync with the logic in + // is_sparse() / is_sparse_csr() / is_mkldnn() + constexpr auto sparse_and_sparsecsr_and_mkldnn_ks = + c10::sparse_ks | c10::sparse_csr_ks | c10::mkldnn_ks; + if (!key_set_.has_any(sparse_and_sparsecsr_and_mkldnn_ks)) { + return kStrided; + } else if (is_sparse()) { + return kSparse; + } else if (is_sparse_compressed()) { + // Typically, the tensor dispatch keys define the tensor layout + // uniquely. This allows using non-virtual layout method for + // better performance. However, when tensor's layout depends, + // say, on tensor attributes, one must use this execution path + // where the corresponding tensor impl class overwrites virtual + // layout_impl() method. + // + // TODO: implement layout() as native function/method so that + // __torch_dispatch__ users will be able to redefine the + // layout() method. + return layout_impl(); + } else { + TORCH_INTERNAL_ASSERT( + is_mkldnn(), "There is an error in the layout calculation logic."); + return kMkldnn; + } + } + + /** + * True if a tensor was auto-wrapped from a C++ or Python number. + * For example, when you write 't + 2', 2 is auto-wrapped into a Tensor + * with `is_wrapped_number_` set to true. + * + * Wrapped numbers do not participate in the result type computation for + * mixed-type operations if there are any Tensors that are not wrapped + * numbers. This is useful, because we want 't + 2' to work with + * any type of tensor, not just LongTensor (which is what integers + * in Python represent). + * + * Otherwise, they behave like their non-wrapped equivalents. + * See [Result type computation] in TensorIterator.h. + * + * Why did we opt for wrapped numbers, as opposed to just having + * an extra function add(Tensor, Scalar)? This helps greatly reduce + * the amount of code we have to write for add, when actually + * a Tensor-Scalar addition is really just a Tensor-Tensor + * addition when the RHS is 0-dim (except for promotion behavior.) + */ + bool is_wrapped_number() const { + return is_wrapped_number_; + } + + /** + * Set whether or not a tensor was auto-wrapped from a C++ or Python + * number. You probably don't want to call this, unless you are + * writing binding code. + */ + void set_wrapped_number(bool value) { + TORCH_INTERNAL_ASSERT(dim() == 0); + is_wrapped_number_ = value; + } + + /** + * Returns true if Tensor supports as_strided and as_strided_backward. + * This is used in autograd to perform inplace update on view Tensors. + * See Note [View + Inplace update for base tensor] and + * [View + Inplace update for view tensor] for details. + * Note this method only returns true for XLA backend, where it + * simulates strided Tensor to support most view ops, but it cannot + * fully support general `as_strided` case. + * It can be expanded as needed in the future, e.g sparse Tensor. + */ + inline bool support_as_strided() const { + if (is_nested()) { + return false; + } + if (key_set_.has(DispatchKey::Functionalize)) { + return false; + } + return device().supports_as_strided(); + } + + // ~~~~~ Autograd API ~~~~~ + // Some methods below are defined in TensorImpl.cpp because Tensor is an + // incomplete type. + + /** + * Set whether or not a tensor requires gradient. + */ + void set_requires_grad(bool requires_grad); + + /** + * True if a tensor requires gradient. Tensors which require gradient + * have history tracked for any operations performed on them, so that + * we can automatically differentiate back to them. A tensor that + * requires gradient and has no history is a "leaf" tensor, which we + * accumulate gradients into. + */ + bool requires_grad() const; + + /** + * Return a mutable reference to the gradient. This is conventionally + * used as `t.grad() = x` to set a gradient to a completely new tensor. + */ + at::Tensor& mutable_grad(); + + /** + * Return the accumulated gradient of a tensor. This gradient is written + * into when performing backwards, when this tensor is a leaf tensor. + */ + const at::Tensor& grad() const; + + /** + * Whether or not the imaginary part of the tensor should be negated + */ + inline bool is_conj() const { + constexpr auto conjugate_ks = DispatchKeySet(DispatchKey::Conjugate); + return key_set_.has_all(conjugate_ks); + } + + /** + * Set whether or not to take the conjugate of the tensor (flip the imaginary + * bit). + */ + void _set_conj(bool value) { + if (value) { + key_set_ = key_set_.add(DispatchKey::Conjugate); + TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype()))); + } else { + key_set_ = key_set_.remove(DispatchKey::Conjugate); + } + } + + /** + * XXX: do not use, private api! + * Update the backend component related keys to the backend component + * corresponding to this device. + */ + void _change_backend_component_keys(c10::Device device); + + /** + * Whether or not the tensor is a zerotensor + */ + inline bool _is_zerotensor() const { + constexpr auto zerotensor_ks = DispatchKeySet(DispatchKey::ZeroTensor); + return key_set_.has_all(zerotensor_ks); + } + + /** + Set whether or not the tensor is a zero tensor + */ + void _set_zero(bool value) { + if (value) { + TORCH_INTERNAL_ASSERT( + false, + "Please call `torch._efficientzerotensor` if you want to create a tensor with no storage."); + } else { + key_set_ = key_set_.remove(DispatchKey::ZeroTensor); + } + } + + /** + * Whether or not the tensor should be negated + */ + inline bool is_neg() const { + constexpr auto negative_ks = DispatchKeySet(DispatchKey::Negative); + return key_set_.has_all(negative_ks); + } + + /** + * Set whether or not to take the conjugate of the tensor (flip the imaginary + * bit). + */ + void _set_neg(bool value) { + if (value) { + key_set_ = key_set_.add(DispatchKey::Negative); + } else { + key_set_ = key_set_.remove(DispatchKey::Negative); + } + } + + /** + * Return the accumulated gradient of a tensor. This gradient is computed + * using forward mode AD. + * + * This is an internal API that should never be used by end users. + * + * The API is as follows: + * - "level" allows to specify the level of forward AD nesting for which the + * gradient should be returned. Note that since levels are not fully + * supported yet, this argument should be 0. See documentation for + * torch::autograd::enter_dual_level for more details about forward AD + * nesting. + * - "self" should represent the Tensor whose forward grad is accessed. It + * is required when dealing with view. + */ + const at::Tensor& _fw_grad(uint64_t level, const at::TensorBase& self) const; + + /** + * Sets the forward gradient for this Tensor. + * The given Tensor might not be used directly and its content will be copied. + * + * This is an internal API that should never be used by end users. + * + * The API is as follows: + * - "new_grad" is a Tensor containing the new value of the gradient that + * should be set + * - "self" should represent the Tensor whose forward grad is accessed. It + * is required when dealing with view. + * - "level" allows to specify the level of forward AD nesting for which the + * gradient should be set. Note that since levels are not fully supported + * yet, this argument should be 0. See documentation for + * torch::autograd::enter_dual_level for more details about forward AD + * nesting. + * - "is_inplace_op" is a boolean flag that tells if this gradient was + * generated by an inplace operation or an out of place one. This allows + * better error checking. + */ + void _set_fw_grad( + const at::TensorBase& new_grad, + const at::TensorBase& self, + uint64_t level, + bool is_inplace_op); + + /** + * Return a typed data pointer to the actual data which this tensor refers to. + * This checks that the requested type (from the template parameter) matches + * the internal type of the tensor. + * + * It is invalid to call data() on a dtype-uninitialized tensor, even if + * the size is 0. + * + * WARNING: If a tensor is not contiguous, you MUST use strides when + * performing index calculations to determine the location of elements in + * the tensor. We recommend using 'TensorAccessor' to handle this computation + * for you; this class is available from 'Tensor'. + */ + template + const T* data_dtype_initialized() const { + return data_dtype_initialized_impl( + [this] { return static_cast(storage_.data()); }); + } + + /** + * Return a mutable typed data pointer to the actual data which this + * tensor refers to. This checks that the requested type (from the + * template parameter) matches the internal type of the tensor. + * + * It is invalid to call data() on a dtype-uninitialized tensor, even if + * the size is 0. + * + * WARNING: If a tensor is not contiguous, you MUST use strides when + * performing index calculations to determine the location of elements in + * the tensor. We recommend using 'TensorAccessor' to handle this computation + * for you; this class is available from 'Tensor'. + */ + template + T* mutable_data_dtype_initialized() { + return data_dtype_initialized_impl( + [this] { return static_cast(storage_.mutable_data()); }); + } + + private: + // Shared implementation of data_dtype_initialized() and + // mutable_data_dtype_initialized(). + template + T* data_dtype_initialized_impl(const Func& get_data) const { + TORCH_CHECK( + data_type_.Match>(), + "Tensor type mismatch, caller expects elements to be ", + caffe2::TypeMeta::TypeName>(), + ", while tensor contains ", + data_type_.name(), + ". "); + return data_ptr_impl_impl(get_data); + } + + public: + /** + * More efficient helper for Tensor::data_ptr(). Like data(), but + * does not do a type check. Unlike the untemplated data(), does + * check has_storage() and storage_initialized(). + */ + template + inline const T* data_ptr_impl() const { + return data_ptr_impl_impl( + [this] { return static_cast(storage_.data()); }); + } + + /** + * More efficient helper for Tensor::data_ptr(). Like data(), but + * does not do a type check. Unlike the untemplated data(), does + * check has_storage() and storage_initialized(). + */ + template + inline T* mutable_data_ptr_impl() { + return data_ptr_impl_impl( + [this] { return static_cast(storage_.mutable_data()); }); + } + + private: + // Shared implementation of mutable_data_ptr_impl() and the future + // mutable_data_ptr_impl(). + template + __ubsan_ignore_pointer_overflow__ T* data_ptr_impl_impl( + const Func& get_data) const { + if (C10_UNLIKELY(!has_storage())) { + throw_data_ptr_access_error(); + } + TORCH_CHECK( + storage_initialized(), + "The tensor has a non-zero number of elements, but its data is not allocated yet.\n" + "If you're using torch.compile/export/fx, it is likely that we are erroneously " + "tracing into a custom kernel. To fix this, please wrap the custom kernel into " + "an opaque custom op. Please see the following for details: " + "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html\n" + "If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call " + "mutable_data() or raw_mutable_data() to actually allocate memory."); + // Caller does the type check. + // Note: storage_offset_ can be non-null even for zero-elements tensors + // (for example if created as `torch.empty(5)[10:]`) that triggers + // applying non-zero offset to null pointer in UBSan + return get_data() + storage_offset_; + } + + public: + /** + * Return a const void* data pointer to the actual data which this + * tensor refers to. + * + * It is invalid to call data() on a dtype-uninitialized tensor, even if the + * size is 0. + * + * WARNING: The data pointed to by this tensor may not contiguous; do NOT + * assume that itemsize() * numel() is sufficient to compute the bytes that + * can be validly read from this tensor. + */ + inline const void* data() const { + return data_impl( + [this] { return static_cast(storage_.data()); }); + } + + /** + * Return a void* data pointer to the actual data which this tensor refers to. + * + * It is invalid to call mutable_data() on a dtype-uninitialized + * tensor, even if the size is 0. + * + * WARNING: The data pointed to by this tensor may not contiguous; do NOT + * assume that itemsize() * numel() is sufficient to compute the bytes that + * can be validly read from this tensor. + */ + inline void* mutable_data() { + return data_impl( + [this] { return static_cast(storage_.mutable_data()); }); + } + + private: + /// Shared implementation of data() and mutable_data(). + /// + /// get_data must return a byte-addressed pointer, e.g. char*, + /// std::byte const*, etc. + template + Void* data_impl(const Func& get_data) const { + if (C10_UNLIKELY(!has_storage())) { + throw_data_ptr_access_error(); + } + TORCH_CHECK( + dtype_initialized(), + "Cannot access data pointer of Tensor that doesn't have initialized dtype " + "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data() on x)"); + auto* data = get_data(); + static_assert( + sizeof(*data) == 1, "get_data must return a byte-addressed pointer."); + // Computing an offset into an empty tensor would be UB, since an empty + // tensor's storage will be nullptr, and adding a nonzero offset to nullptr + // is UB. So we skip the offset computation in this case. + if (is_empty()) { + return nullptr; + } + return data + data_type_.itemsize() * storage_offset_; + } + + public: + /** + * Returns the TypeMeta of a tensor, which describes what data type + * it is (e.g., int, float, ...) + */ + const caffe2::TypeMeta dtype() const { + return data_type_; + } + + /** + * Return the size of a single element of this tensor in bytes. + */ + size_t itemsize() const { + TORCH_CHECK( + dtype_initialized(), + "Cannot report itemsize of Tensor that doesn't have initialized dtype " + "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data() on x)"); + return data_type_.itemsize(); + } + + void set_backend_meta(intrusive_ptr backend_meta) { + get_extra_meta().backend_meta_ = std::move(backend_meta); + } + + c10::BackendMeta* get_backend_meta() { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->backend_meta_.get(); + } + + intrusive_ptr get_backend_meta_intrusive_ptr() const { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->backend_meta_; + } + + void release_storage_and_set_meta_custom_data_ptr_error_msg_( + std::optional s) { + storage_ = {}; + set_storage_access_should_throw(); + get_extra_meta().custom_data_ptr_error_msg_ = s; + get_extra_meta().custom_storage_error_msg_ = std::move(s); + } + + protected: + /** + * Returns the human-readable name of the actual type of this object (e.g., + * TensorImpl, BatchedTensorImpl, etc.). Used for error messages. + */ + virtual const char* tensorimpl_type_name() const { + return "TensorImpl"; + } + + private: + [[noreturn]] void throw_storage_access_error() const; + [[noreturn]] void throw_data_ptr_access_error() const; + + ExtraMeta& get_extra_meta() { + if (!extra_meta_) { + extra_meta_ = std::make_unique(); + } + return *extra_meta_; + } + + c10::SymbolicShapeMeta& symbolic_shape_meta() { + TORCH_INTERNAL_ASSERT(extra_meta_ && extra_meta_->symbolic_shape_meta_); + return *extra_meta_->symbolic_shape_meta_; + } + + const c10::SymbolicShapeMeta& symbolic_shape_meta() const { + TORCH_INTERNAL_ASSERT(extra_meta_ && extra_meta_->symbolic_shape_meta_); + return *extra_meta_->symbolic_shape_meta_; + } + + public: + /** + * True if a tensor has no elements (e.g., numel() == 0). + */ + inline bool is_empty() const { + return numel() == 0; + } + + // if we are going to use sym sizes, we should be setting sym strides at the + // same time, otherwise it's very easy to misuse this API + void set_sizes_and_strides( + c10::SymIntArrayRef sizes, + c10::SymIntArrayRef strides, + std::optional storage_offset = std::nullopt); + // This is renamed to avoid breaking overload BC + void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes); + void generic_set_sizes_contiguous(c10::IntArrayRef sizes) { + set_sizes_contiguous(sizes); + } + + /** + * Change the size at some dimension. This DOES NOT update strides; + * thus, most changes to size will not preserve contiguity. You probably + * also want to call set_stride() when you call this. + * + * TODO: This should be jettisoned in favor of `set_sizes_and_strides`, + * which is harder to misuse. + */ + virtual void set_size(int64_t dim, int64_t new_size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_size ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !matches_policy(SizesStridesPolicy::CustomSizes), + "set_size() called on tensor with dynamic shapes or customized size behavior") + sizes_and_strides_.size_at(dim) = new_size; + refresh_numel(); + refresh_contiguous(); + } + + /** + * Change the stride at some dimension. + * + * TODO: This should be jettisoned in favor of `set_sizes_and_strides`, + * which is harder to misuse. + */ + virtual void set_stride(int64_t dim, int64_t new_stride) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_stride ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "set_stride() called on tensor with symbolic shape") + sizes_and_strides_.stride_at_unchecked(dim) = new_stride; + refresh_contiguous(); + } + + /** + * Set the offset into the storage of this tensor. + * + * WARNING: This does NOT check if the tensor is in bounds for the new + * location at the storage; the caller is responsible for checking this + * (and resizing if necessary.) + */ + virtual void set_storage_offset(int64_t storage_offset) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_storage_offset ", + err_msg_tensor_metadata_change_not_allowed); + // TODO: this should probably consult policy + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "set_storage_offset() called on tensor with symbolic shape") + storage_offset_ = storage_offset; + } + + /** + * Like set_sizes_and_strides but assumes contiguous strides. + * + * WARNING: This function does not check if the requested + * sizes/strides are in bounds for the storage that is allocated; + * this is the responsibility of the caller + */ + void set_sizes_contiguous(IntArrayRef new_size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_sizes_contiguous ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !matches_policy(SizesStridesPolicy::CustomStrides), + "tried to directly modify sizes for customized tensor"); + sizes_and_strides_.set_sizes(new_size); + + refresh_numel(); + empty_tensor_restride( + MemoryFormat::Contiguous); // calls refresh_contiguous() + } + + C10_ALWAYS_INLINE const impl::SizesAndStrides& sizes_and_strides() { + return sizes_and_strides_; + } + + /** + * Set the sizes and strides of a tensor. + * + * WARNING: This function does not check if the requested + * sizes/strides are in bounds for the storage that is allocated; + * this is the responsibility of the caller + */ + void set_sizes_and_strides( + IntArrayRef new_size, + IntArrayRef new_stride, + std::optional storage_offset = std::nullopt) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_sizes_and_strides ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "set_sizes_and_strides() called on tensor with symbolic shape") + TORCH_CHECK( + new_size.size() == new_stride.size(), + "dimensionality of sizes (", + new_size.size(), + ") must match dimensionality of strides (", + new_stride.size(), + ")"); + const auto new_dim = new_size.size(); + bool overflowed = false; + sizes_and_strides_.set_sizes(new_size); + + if (new_dim > 0) { + for (size_t dim = new_dim - 1;; dim--) { + if (new_stride[dim] >= 0) { + sizes_and_strides_.stride_at_unchecked(dim) = new_stride[dim]; + } else { + // XXX: This behavior is surprising and may need to be removed to + // support negative strides. Some pytorch functions rely on it: + // for example, torch.cat (run TestTorch.test_cat_empty). + if (dim == new_dim - 1) { + sizes_and_strides_.stride_at_unchecked(dim) = 1; + } else { + // Keep stride monotonically increasing to match NumPy. + overflowed |= c10::mul_overflows( + sizes_and_strides_.stride_at_unchecked(dim + 1), + std::max( + sizes_and_strides_.size_at_unchecked(dim + 1), 1), + std::addressof(sizes_and_strides_.stride_at_unchecked(dim))); + } + } + if (dim == 0) + break; + } + TORCH_CHECK(!overflowed, "Stride calculation overflowed"); + } + + refresh_numel(); + refresh_contiguous(); + + if (storage_offset.has_value()) { + storage_offset_ = *storage_offset; + } + } + + /** + * Set whether a tensor allows changes to its metadata (e.g. sizes / strides / + * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor + * ] for details. + */ + void set_allow_tensor_metadata_change(bool value [[maybe_unused]]) { + // TODO: at some point, we should kill this field completely. + allow_tensor_metadata_change_ = true; + } + + /** + * True if a tensor allows changes to its metadata (e.g. sizes / strides / + * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor + * ] for details. + */ + bool allow_tensor_metadata_change() const { + return allow_tensor_metadata_change_; + } + + /** + * Set the pointer to autograd metadata. + */ + void set_autograd_meta( + std::unique_ptr autograd_meta); + + /** + * Return the pointer to autograd metadata. May return nullptr if the + * tensor does not track gradients. + */ + c10::AutogradMetaInterface* autograd_meta() const; + + /** + * Set the pointer to named tensor metadata. + */ + void set_named_tensor_meta( + std::unique_ptr named_tensor_meta) { + TORCH_WARN_ONCE( + "Named tensors and all their associated APIs are an experimental feature ", + "and subject to change. Please do not use them for anything important ", + "until they are released as stable."); +#ifdef DEBUG + if (named_tensor_meta) { + TORCH_INTERNAL_ASSERT(named_tensor_meta->slow_dim() == dim()); + } +#endif + if (named_tensor_meta) { + get_extra_meta().named_tensor_meta_ = std::move(named_tensor_meta); + key_set_ = key_set_.add(DispatchKey::Named); + } else { + if (extra_meta_) { + extra_meta_->named_tensor_meta_ = nullptr; + } + key_set_ = key_set_.remove(DispatchKey::Named); + } + } + + void set_python_dispatch(bool k) { + if (k) { + key_set_ = key_set_.add(c10::python_ks); + } else { + key_set_ = key_set_ - c10::python_ks; + } + } + + bool is_python_dispatch() const { + return key_set_.has_all(c10::python_ks); + } + + /** + * Return the pointer to named tensor metadata. + */ + const c10::NamedTensorMetaInterface* named_tensor_meta() const { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->named_tensor_meta_.get(); + } + + c10::NamedTensorMetaInterface* named_tensor_meta() { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->named_tensor_meta_.get(); + } + + bool has_named_tensor_meta() const { + if (!extra_meta_) { + return false; + } + return extra_meta_->named_tensor_meta_ != nullptr; + } + + // NOTE [ TensorImpl Shallow-Copying ] + // + // TensorImpl shallow-copying is used when we want to have two Variables share + // the same tensor metadata (e.g. sizes / strides / storage pointer / + // storage_offset), but each with a different autograd history. Example call + // sites: + // + // 1. `var_detached = var.detach()` uses `shallow_copy_and_detach()` to create + // `var_detached` that shares the same tensor metadata with `var`, but with a + // completely new autograd history. + // 2. `var.set_data(tensor)` uses `shallow_copy_from()` to copy tensor + // metadata from `tensor` into `var`, while keeping `var`'s original + // AutogradMeta. + // + // Functions that shallow-copy a TensorImpl (such as + // `shallow_copy_and_detach()` / `shallow_copy_from()` / + // `copy_tensor_metadata()`) copy the tensor metadata fields (e.g. sizes / + // strides / storage pointer / storage_offset) by value. However, the + // following fields are not copied: + // + // 1. the AutogradMeta pointer, because it is unique for each Variable. + // 2. the version counter, because the destination TensorImpl's version + // counter is either set to the passed-in `version_counter` (in + // `shallow_copy_and_detach()` and `copy_tensor_metadata()`), or it is kept + // intact (in `shallow_copy_from()`). See NOTE [ Version Counter Sharing ] for + // details. + // + // In `shallow_copy_and_detach()` and `copy_tensor_metadata()`, the passed-in + // `allow_tensor_metadata_change` determines whether the TensorImpl + // shallow-copy allows changes to its metadata (e.g. sizes / strides / storage + // / storage_offset). See NOTE [ Metadata Change for a Detached Tensor ] for + // details. + // + // In `shallow_copy_from()`, we don't check the destination TensorImpl's + // `allow_tensor_metadata_change_`, because `shallow_copy_from()` is used for + // implementing functions such as `var.set_data(tensor)`, which changes + // `var`'s tensor metadata and expects its `allow_tensor_metadata_change_` to + // be ignored. + + /** + * One TensorImpl can be copied to another TensorImpl if they have the same + * DispatchKeySet. The only two special cases (for legacy reason) are: + * CPU is compatible with CUDA and SparseCPU is + * compatible with SparseCUDA. + */ + inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + auto is_dense = [](DispatchKeySet ts) { + constexpr auto dense_backends = DispatchKeySet( + {BackendComponent::CPUBit, + BackendComponent::CUDABit, + BackendComponent::MPSBit, + BackendComponent::HIPBit, + BackendComponent::XPUBit, + BackendComponent::HPUBit, + BackendComponent::MTIABit}); + constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense); + return ts.has_any(dense_k) && ts.has_any(dense_backends); + }; + auto is_sparse = [](DispatchKeySet ts) { + constexpr auto sparse_backends = DispatchKeySet( + {BackendComponent::CPUBit, + BackendComponent::CUDABit, + BackendComponent::MPSBit, + BackendComponent::HIPBit, + BackendComponent::XPUBit}); + constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse); + return ts.has_any(sparse_k) && ts.has_any(sparse_backends); + }; + auto is_sparse_compressed = [](DispatchKeySet ts) { + constexpr auto sparse_compressed_k = + DispatchKeySet(DispatchKey::SparseCsr); + return ts.has_any(sparse_compressed_k); + }; + return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || + (is_sparse(key_set_) && is_sparse(from)) || + (is_sparse_compressed(key_set_) && is_sparse_compressed(from)); + ; + } + + private: + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + public: + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + virtual c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const; + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + virtual c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + /** + * Shallow-copies data from another TensorImpl into this TensorImpl. + * + * For why this function doesn't check this TensorImpl's + * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. + */ + virtual void shallow_copy_from(const c10::intrusive_ptr& impl) { + copy_tensor_metadata( + /*src_impl=*/impl.get(), + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + } + + // Inference tensor doesn't have version counter, + // set_version_counter is no-op for them. + void set_version_counter(const c10::VariableVersion& version_counter) { + TORCH_CHECK( + !(is_inference() && version_counter.enabled()), + "Cannot set version_counter for inference tensor"); + version_counter_ = version_counter; + } + + void set_version_counter(c10::VariableVersion&& version_counter) { + TORCH_CHECK( + !(is_inference() && version_counter.enabled()), + "Cannot set version_counter for inference tensor"); + version_counter_ = std::move(version_counter); + } + + const c10::VariableVersion& version_counter() const noexcept { + return version_counter_; + } + + void bump_version() { + version_counter_.bump(); + } + + impl::PyObjectSlot* pyobj_slot() { + return &pyobj_slot_; + } + + const impl::PyObjectSlot* pyobj_slot() const { + return &pyobj_slot_; + } + + void incref_pyobject() const noexcept override final; + + void decref_pyobject() const noexcept override final; + + bool try_incref_pyobject() const noexcept override final; + + private: + // See NOTE [std::optional operator usage in CUDA] + // We probably don't want to expose this publicly until + // the note is addressed. + std::optional device_opt() const { + return device_opt_; + } + + public: + /** + * The device type of a Tensor, e.g., DeviceType::CPU or DeviceType::CUDA. + */ + DeviceType device_type() const { + // TODO: A useful internal assert would be to show that device_opt_ is null + // only if you are an undefined tensor + TORCH_CHECK( + device_opt_.has_value(), + "device_type cannot be run on undefined Tensor"); + // See NOTE [std::optional operator usage in CUDA] + return (*device_opt_).type(); + } + + /** + * @brief Extends the outer-most dimension of this tensor by num elements, + * preserving the existing data. + * + * The underlying data may be reallocated in order to accommodate the new + * elements, in which case this tensors' capacity is grown at a factor of + * growthPct. This ensures that Extend runs on an amortized O(1) time + * complexity. + * + * This op is auto-asynchronous if the underlying device (CUDA) supports it. + */ + void Extend(int64_t num, float growthPct); + + /** + * @brief Reserve space for the underlying tensor. + * + * This must be called after Resize(), since we only specify the first + * dimension This does not copy over the old data to the newly allocated space + */ + void ReserveSpace(int64_t outer_dim); + + /** + * @brief Resizes a tensor. + * + * Resize takes in a vector of ints specifying the dimensions of the tensor. + * You can pass in an empty vector to specify that it is a scalar (i.e. + * containing one single item). + * + * The underlying storage may be deleted after calling Resize: if the new + * shape leads to a different number of items in the tensor, the old memory + * is deleted and new memory will be allocated next time you call + * mutable_data(). However, if the shape is different but the total number of + * items is the same, the underlying storage is kept. + * + * This method respects caffe2_keep_on_shrink. Consult the internal logic + * of this method to see exactly under what circumstances this flag matters. + */ + template + void Resize(Ts... dim_source) { + bool size_changed = SetDims(dim_source...); + if (size_changed) { + HandleResize(); + } + } + + template + void Resize(const std::vector& dim_source) { + Resize(ArrayRef(dim_source)); + } + + /** + * Resizes the tensor without touching underlying storage. + * This requires the total size of the tensor to remains constant. + */ + void Reshape(const std::vector& dims); + + /** + * Release whatever memory the tensor was holding but keep size and type + * information. Subsequent call to mutable_data will trigger new memory + * allocation. + */ + void FreeMemory(); + + /** + * @brief Shares the data with another tensor. + * + * To share data between two tensors, the sizes of the two tensors must be + * equal already. The reason we do not implicitly do a Resize to make the two + * tensors have the same shape is that we want to allow tensors of different + * shapes but the same number of items to still be able to share data. This + * allows one to e.g. have a n-dimensional Tensor and a flattened version + * sharing the same underlying storage. + * + * The source tensor should already have its data allocated. + */ + // To be deprecated + void ShareData(const TensorImpl& src); + + void ShareExternalPointer( + DataPtr&& data_ptr, + const caffe2::TypeMeta data_type, + size_t size_bytes); + + /** + * Returns a mutable raw pointer of the underlying storage. Since we will need + * to know the type of the data for allocation, a TypeMeta object is passed in + * to specify the necessary information. This is conceptually equivalent of + * calling mutable_data() where the TypeMeta parameter meta is derived from + * the type T. This function differs from mutable_data() in the sense that + * the type T can be specified during runtime via the TypeMeta object. + * + * If the existing data does not match the desired type, it will be deleted + * and a new storage will be created. + */ + inline void* raw_mutable_data(const caffe2::TypeMeta& meta) { + // For 0-size tensors it's fine to return any pointer (including nullptr) + if (data_type_ == meta && storage_initialized()) { + return static_cast( + static_cast(storage_.mutable_data()) + + storage_offset_ * meta.itemsize()); + } else { + bool had_special_dtor = data_type_.placementDelete() != nullptr; + storage_offset_ = 0; + data_type_ = meta; + // NB: device is not changed + + // We can reuse the existing buffer if the current data does not have + // a special destructor and the new data doesn't have a special + // constructor. + if (numel_ == 0 || + (meta.placementNew() == nullptr && !had_special_dtor && + (storage_.nbytes() >= (numel_ * data_type_.itemsize())))) { + TORCH_INTERNAL_ASSERT( + storage_offset_ == 0); // because we just reallocated + return storage_.mutable_data(); + } + Allocator* allocator = storage_.allocator(); + // Storage might have nullptr allocator in rare cases, for example, if + // an external memory segment has been wrapped with Tensor and we don't + // know how to reallocate it. However, in order to preserve legacy C2 + // behavior, we allow reallocating the memory using default allocator. + if (allocator == nullptr) { + allocator = GetAllocator(storage_.device_type()); + } + if (meta.placementNew()) { + // For types that need placement new, we will call it, as well as + // making sure that when the data is freed, it calls the right + // destruction procedure. + auto size = numel_; + auto dtor = data_type_.placementDelete(); + auto data_ptr = allocator->allocate(numel_ * data_type_.itemsize()); + storage_.set_data_ptr_noswap(PlacementDeleteContext::makeDataPtr( + std::move(data_ptr), dtor, size, storage_.device())); + data_type_.placementNew()(storage_.mutable_data(), numel_); + } else { + // For fundamental type, new and delete is easier. + storage_.set_data_ptr_noswap( + allocator->allocate(numel_ * data_type_.itemsize())); + } + storage_.set_nbytes(numel_ * data_type_.itemsize()); + TORCH_INTERNAL_ASSERT( + storage_offset_ == 0); // because we just reallocated + device_opt_ = storage_.device(); + return storage_.mutable_data(); + } + } + + /** + * Returns a typed pointer of the underlying storage. + * + * For fundamental types, we reuse possible existing storage if there + * is sufficient capacity. + */ + template + inline T* mutable_data() { + if (storage_initialized() && data_type_.Match()) { + return static_cast(storage_.mutable_data()) + storage_offset_; + } + // Check it here statically - otherwise TypeMeta would throw the runtime + // error in attempt to invoke TypeMeta::ctor() + static_assert( + std::is_default_constructible_v, + "Tensor can't hold non-default-constructable types"); + return static_cast(raw_mutable_data(caffe2::TypeMeta::Make())); + } + + /** + * True if a tensor is storage initialized. A tensor may become + * storage UNINITIALIZED after a Resize() or FreeMemory() + */ + bool storage_initialized() const { + TORCH_CHECK( + has_storage(), + "cannot call storage_initialized on tensor that does not have storage"); + return storage_.data() || numel_ == 0; + } + + /** + * True if a tensor is dtype initialized. A tensor allocated with + * Caffe2-style constructors is dtype uninitialized until the + * first time mutable_data() is called. + */ + bool dtype_initialized() const noexcept { + return data_type_ != caffe2::TypeMeta(); + } + + void set_storage_keep_dtype(at::Storage storage) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_storage ", + err_msg_tensor_metadata_change_not_allowed); + storage_ = std::move(storage); + device_opt_ = storage_.device(); + } + + void set_storage_and_dtype( + at::Storage storage, + const caffe2::TypeMeta data_type) { + set_storage_keep_dtype(std::move(storage)); + data_type_ = data_type; + } + + void empty_tensor_restride_symint(MemoryFormat memory_format); + + /** + * Set the strides of the tensor to match memory_format + * + * WARNING: This function doesn't rearrange data and assumes tensor is a + * memory contiguous + */ + void empty_tensor_restride(MemoryFormat memory_format) { + if (has_symbolic_sizes_strides_) { + empty_tensor_restride_symint(memory_format); + return; + } +#ifdef DEBUG + TORCH_INTERNAL_ASSERT( + compute_numel() == numel_, + "If you are seeing this error, that means empty_tensor_restride was " + "called before setting correct numel"); +#endif + switch (memory_format) { + case MemoryFormat::Contiguous: { + // dim_ is a virtual call, don't repeat it + const auto dim_ = dim(); + sizes_and_strides_.resize(dim_); + if (dim_ > 0) { + bool overflowed = false; + const auto last_idx = dim_ - 1; + sizes_and_strides_.stride_at_unchecked(last_idx) = 1; + for (auto i = last_idx - 1; i >= 0; --i) { + overflowed |= c10::mul_overflows( + sizes_and_strides_.stride_at_unchecked(i + 1), + std::max( + sizes_and_strides_.size_at_unchecked(i + 1), 1), + std::addressof(sizes_and_strides_.stride_at_unchecked(i))); + } + TORCH_CHECK(!overflowed, "Stride calculation overflowed"); + } + break; + } + case MemoryFormat::ChannelsLast: { + TORCH_CHECK( + dim() == 4, "required rank 4 tensor to use channels_last format"); + set_sizes_and_strides(sizes(), get_channels_last_strides_2d(sizes())); + break; + } + case MemoryFormat::ChannelsLast3d: { + TORCH_CHECK( + dim() == 5, + "required rank 5 tensor to use channels_last_3d format"); + set_sizes_and_strides(sizes(), get_channels_last_strides_3d(sizes())); + break; + } + case MemoryFormat::Preserve: + TORCH_CHECK(false, "unsupported memory format ", memory_format); + // Cleaning warning messages, no need to break as TORCH_CHECK(false) + // terminates flow. + // break; + case MemoryFormat::NumOptions: + TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format); + } + // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually + // exclusive see #24090 + refresh_contiguous(); + } + + bool is_strides_like(at::MemoryFormat memory_format) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return is_strides_like_custom(memory_format); + } + return is_strides_like_default(memory_format); + } + + bool is_strides_like_channels_last() const { + return is_strides_like(at::MemoryFormat::ChannelsLast); + } + + bool is_strides_like_channels_last_3d() const { + return is_strides_like(at::MemoryFormat::ChannelsLast3d); + } + + bool is_non_overlapping_and_dense_or_false() const { + return sym_is_non_overlapping_and_dense().guard_or_false( + __FILE__, __LINE__); + } + + bool is_non_overlapping_and_dense() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return is_non_overlapping_and_dense_custom(); + } + return is_non_overlapping_and_dense_default(); + } + + SymBool sym_is_non_overlapping_and_dense() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_is_non_overlapping_and_dense_custom(); + } + return sym_is_non_overlapping_and_dense_default(); + } + + // if this returns true, then it is guaranteed that this tensor has symbolic + // sizes/strides + bool has_symbolic_sizes_strides() const { + return has_symbolic_sizes_strides_; + } + + private: + void HandleResize(); + + // The Caffe2 Resize() method supports being called both as Resize({2,2}) as + // well as variadic with Resize(2, 2). These overloads provide all of the + // supported calling configurations, while being overloads (and not templates) + // so that implicit conversions still work. + // + // SetDims on ArrayRef is internally implemented as a template, so we can + // handle both ArrayRefs of different types (there are some uses of + // Resize in Caffe2 which pass in int, not int64_t.) + + template < + typename T, + typename = typename std::enable_if_t>> + bool SetDimsTemplate(ArrayRef src) { + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "SetDims() called on tensor with symbolic shape") + + auto old_numel = numel_; + sizes_and_strides_.resize(src.size()); + int64_t new_numel = 1; + for (const auto i : c10::irange(src.size())) { + new_numel *= src[i]; + sizes_and_strides_.size_at_unchecked(i) = src[i]; + } + numel_ = new_numel; + empty_tensor_restride(MemoryFormat::Contiguous); + return numel_ != old_numel; + } + + bool SetDims(ArrayRef s) { + return SetDimsTemplate(s); + } + + bool SetDims(ArrayRef s) { + return SetDimsTemplate(s); + } + + bool SetDims(ArrayRef s) { + return SetDimsTemplate(s); + } + + bool SetDims() { + return SetDims(IntArrayRef{}); + } + + bool SetDims(const int64_t d0) { + return SetDims(IntArrayRef{d0}); + } + + bool SetDims(const int64_t d0, const int64_t d1) { + return SetDims(IntArrayRef{d0, d1}); + } + + bool SetDims(const int64_t d0, const int64_t d1, const int64_t d2) { + return SetDims(IntArrayRef{d0, d1, d2}); + } + + bool SetDims( + const int64_t d0, + const int64_t d1, + const int64_t d2, + const int64_t d3) { + return SetDims(IntArrayRef{d0, d1, d2, d3}); + } + + /** + * Compute the number of elements based on the sizes of a tensor. + */ + // NB: This is ONLY called when sizes_and_strides_ is used directly; if + // we are virtualizing, then numel calls are virtualized as well, and this + // should never get called + int64_t compute_numel() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_); +#if C10_HAS_BUILTIN_OVERFLOW() && !defined(C10_MOBILE) + // Use overflow checks if supported by the compiler + return safe_compute_numel(); +#else + return c10::multiply_integers(sizes_and_strides_.sizes_arrayref()); +#endif + } + + /** + * Compute the number of elements based on the sizes of a + * tensor. Catches integer overflow that may occur when a tensor + * using a sparse layout has multiple dimensions with large sizes. + */ + int64_t safe_compute_numel() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_); + uint64_t n = 1; + bool overflows = + c10::safe_multiplies_u64(sizes_and_strides_.sizes_arrayref(), &n); + constexpr auto numel_max = std::min( + static_cast(std::numeric_limits::max()), + static_cast(std::numeric_limits::max())); + + overflows |= (n > numel_max); + TORCH_CHECK(!overflows, "numel: integer multiplication overflow"); + return static_cast(n); + } + + /** + * Compute whether or not a tensor is contiguous based on the sizes and + * strides of a tensor. + */ + bool compute_contiguous() const; + + bool compute_channels_last_contiguous_2d() const; + + bool compute_channels_last_contiguous_3d() const; + + bool compute_strides_like_channels_last_2d() const; + + bool compute_strides_like_channels_last_3d() const; + + bool compute_non_overlapping_and_dense() const; + + protected: + /** + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. + * + * For tensors with sparse layouts, use safe_refresh_numel() instead + * because it will catch integer overflow that may occur for tensors + * with sparse layouts and large dimensions. + * + * NB: We may uselessly recompute cached numel even in situations where + * it is completely never used (e.g., if CustomSizes for Python). However, + * we still must keep it up to date in case the Python overload + * returns None (in which case we will consult the field here). This also + * implies that sizes/strides will never be complete garbage; in the + * very worst case scenario, it will reflect a 1-dim zero size tensor. + */ + void refresh_numel() { + if (has_symbolic_sizes_strides_) { + symbolic_shape_meta().refresh_numel(); + } else { + numel_ = compute_numel(); + } + } + + /** + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. Use only for tensors with sparse layouts because only + * sparse tensor are likely to have sizes that may lead to integer + * overflow when computing numel. + */ + void safe_refresh_numel() { + if (has_symbolic_sizes_strides_) { + // NB: sym numel is done with symbolic integers, which handle overflow + // checking + symbolic_shape_meta().refresh_numel(); + } else { + numel_ = safe_compute_numel(); + } + } + + private: + void _set_is_contiguous(bool b) { + is_contiguous_ = b; + } + + void _set_is_channels_last_contiguous(bool b) { + is_channels_last_contiguous_ = b; + } + + void _set_is_channels_last_3d_contiguous(bool b) { + is_channels_last_3d_contiguous_ = b; + } + + void _set_is_channels_last(bool b) { + is_channels_last_ = b; + } + + void _set_is_channels_last_3d(bool b) { + is_channels_last_3d_ = b; + } + + void _set_is_non_overlapping_and_dense(bool b) { + is_non_overlapping_and_dense_ = b; + } + + // These are little wrappers over the real compute_ functions that + // can make use of other contiguity fields to short circuit. + + bool compute_is_non_overlapping_and_dense_dim4() { + return is_contiguous_ || is_channels_last_contiguous_ || + compute_non_overlapping_and_dense(); + } + + bool compute_channels_last_contiguous_3d_dim5() { + return !is_channels_last_contiguous_ && + compute_channels_last_contiguous_3d(); + } + + bool compute_channels_last_2d_dim5() { + return !is_channels_last_3d_contiguous_ && + compute_strides_like_channels_last_2d(); + } + + bool compute_channels_last_3d_dim5() { + return !is_channels_last_ && compute_strides_like_channels_last_3d(); + } + + bool compute_is_non_overlapping_and_dense_dim5() { + return is_contiguous_ || is_channels_last_contiguous_ || + is_channels_last_3d_contiguous_ || compute_non_overlapping_and_dense(); + } + + bool compute_is_non_overlapping_and_dense_anydim() { + return is_contiguous_ || compute_non_overlapping_and_dense(); + } + + void _refresh_contiguous() { + // Note: + // Dim 0, 1, 2 will never be a channels last 2d/3d format + // Dim 3+ is possibly be a channels last 2d format (Dim 4 only at this + // point) Dim 4+ is possibly be a channels last 3d format (Dim 5 only at + // this point) + switch (dim()) { + case 4: { + _set_is_contiguous(compute_contiguous()); + _set_is_channels_last_contiguous(compute_channels_last_contiguous_2d()); + _set_is_channels_last_3d_contiguous(false); + _set_is_channels_last(compute_strides_like_channels_last_2d()); + _set_is_channels_last_3d(false); + _set_is_non_overlapping_and_dense( + compute_is_non_overlapping_and_dense_dim4()); + break; + } + case 5: { + _set_is_contiguous(compute_contiguous()); + _set_is_channels_last_contiguous(compute_channels_last_contiguous_2d()); + _set_is_channels_last_3d_contiguous( + compute_channels_last_contiguous_3d_dim5()); + _set_is_channels_last(compute_channels_last_2d_dim5()); + _set_is_channels_last_3d(compute_channels_last_3d_dim5()); + _set_is_non_overlapping_and_dense( + compute_is_non_overlapping_and_dense_dim5()); + break; + } + default: + // is_channels_last_ and is_channels_last_3d_ are suggested + // memory_format. Being channels_last_contiguous doesn't necessarily + // mean the tensor is strided like channels_last: for strides on channel + // dimension could suggest desired memory_layout, but it doesn't affect + // memory storage + _set_is_contiguous(compute_contiguous()); + _set_is_channels_last_contiguous(false); + _set_is_channels_last_3d_contiguous(false); + _set_is_channels_last(false); + _set_is_channels_last_3d(false); + _set_is_non_overlapping_and_dense( + compute_is_non_overlapping_and_dense_anydim()); + break; + } + } + + protected: + /** + * Recompute the cached contiguity of a tensor. Call this if you modify sizes + * or strides. + */ + void refresh_contiguous() { + if (has_symbolic_sizes_strides_) { + symbolic_shape_meta().refresh_contiguous(); + } else { + _refresh_contiguous(); + } + } + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change); + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change); + + private: + static void copy_tensor_metadata_except_version_counter( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + bool allow_tensor_metadata_change); + + protected: + // Error message to show when the user tries to change tensor metadata on + // Tensor created from .data or .detach(). + // + // See NOTE [ Metadata Change for a Detached Tensor ] for details. + static const char* const err_msg_tensor_metadata_change_not_allowed; + + static void copy_generic_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl); + + public: + void set_storage_access_should_throw() { + storage_access_should_throw_ = true; + } + + public: + void set_custom_sizes_strides(SizesStridesPolicy policy) { + custom_sizes_strides_ = static_cast(policy); + refresh_sizes_strides_policy(); + } + + void set_python_custom_sizes_strides(SizesStridesPolicy policy) { + python_custom_sizes_strides_ = static_cast(policy); + refresh_sizes_strides_policy(); + } + + void set_custom_device(bool custom_device) { + custom_device_ = custom_device; + refresh_device_policy(); + } + + void set_custom_layout(bool custom_layout) { + custom_layout_ = custom_layout; + refresh_layout_policy(); + } + + void set_python_custom_device(bool custom_device) { + python_custom_device_ = custom_device; + refresh_device_policy(); + } + + void set_python_custom_layout(bool custom_layout) { + python_custom_layout_ = custom_layout; + refresh_layout_policy(); + } + + protected: + void refresh_sizes_strides_policy() { + if (has_symbolic_sizes_strides_) { + sizes_strides_policy_ = + static_cast(SizesStridesPolicy::CustomSizes); + } else { + sizes_strides_policy_ = + std::max(custom_sizes_strides_, python_custom_sizes_strides_); + } + } + + void refresh_device_policy() { + device_policy_ = custom_device_ || python_custom_device_; + } + + void refresh_layout_policy() { + layout_policy_ = custom_layout_ || python_custom_layout_; + } + + protected: + Storage storage_; + + private: + // This pointer points to an AutogradMeta struct that stores autograd-specific + // fields (such as grad_ / grad_fn_ / grad_accumulator_). This pointer always + // has unique ownership (meaning only one TensorImpl can own it at a time). + // + // autograd_meta_ can be nullptr, as an optimization. When this occurs, it is + // equivalent to having an autograd_meta_ pointing to a default constructed + // AutogradMeta; intuitively, tensors which don't require grad will have this + // field set to null. + // + // This means accessors on autograd_meta_ have to be careful to test if they + // got a nullptr, and handle default behavior appropriately in that case. + // + // Note that we don't enforce the invariant that if the AutogradMeta is + // default constructed, it is nullptr (to do this, we'd have to continuously + // check if an AutogradMeta became, by mutation, equal to the default + // constructed form. (This might be useful, but it seems rare enough that + // a requires_grad=True variable will turn back into the requires_grad=False + // version.) So there are three representable states: + // + // 1. autograd_meta_ == nullptr + // 2. autograd_meta_ is default constructed (semantically, same as (1)) + // 3. autograd_meta_ has nontrivial information content + // + std::unique_ptr autograd_meta_ = nullptr; + + protected: + std::unique_ptr extra_meta_ = nullptr; + + c10::VariableVersion version_counter_; + + impl::PyObjectSlot pyobj_slot_; + + c10::impl::SizesAndStrides sizes_and_strides_; + + int64_t storage_offset_ = 0; + // If sizes and strides are empty, the numel is 1!! However, most of the + // time, we will immediately set sizes to {0} and reset numel to 0. + // (Can't do that in the default initializers, because there's no way to + // spell "allocate a one-element array" for strides_). + int64_t numel_ = 1; + + // INVARIANT: When storage is non-null, this type meta must + // agree with the type meta in storage + caffe2::TypeMeta data_type_; + + // NOTE [std::optional operator usage in CUDA] + // Our optional definition doesn't compile in .cu file if `value()` or + // `operator->` are used. Instead, we always use `operator*`. + // See https://github.com/pytorch/pytorch/issues/18496 for more info. + // If this is too burdensome to maintain, we can just + // manually implement this with an additional bool. + + // INVARIANT: When storage is non-null, this Device must + // agree with the type meta in storage. + // + // INVARIANT: device_opt_ is only nullopt for undefined tensors + // (which do not have a device.) + std::optional device_opt_; + + // default member initializers for bit-fields only available with -std=c++2a + // or -std=gnu++2a + inline void init_bitfields() { + is_contiguous_ = true; + is_channels_last_ = false; + is_channels_last_contiguous_ = false; + is_channels_last_3d_ = false; + is_channels_last_3d_contiguous_ = false; + is_non_overlapping_and_dense_ = true; + is_wrapped_number_ = false; + allow_tensor_metadata_change_ = true; + reserved_ = false; + sizes_strides_policy_ = static_cast(SizesStridesPolicy::Default); + custom_sizes_strides_ = static_cast(SizesStridesPolicy::Default); + python_custom_sizes_strides_ = + static_cast(SizesStridesPolicy::Default); + python_custom_device_ = false; + python_custom_layout_ = false; + custom_device_ = false; + custom_layout_ = false; + device_policy_ = false; + layout_policy_ = false; + storage_access_should_throw_ = false; + has_symbolic_sizes_strides_ = false; + } + + // Tensor is contiguous + bool is_contiguous_ : 1; + + // Tensor is a subclass that does not permit storage access. + bool storage_access_should_throw_ : 1; + + // Tensor is stored in the channels last 2d memory format, when dimensions + // order is (N)CHW and C-strides < W-strides < H-strides (< N-strides) + // (If size of any dimension is equal to 1, this dimension strides value + // is not taken into account). + bool is_channels_last_ : 1; + + // Channels last contiguous tensor is channel last tensor which occupies + // contiguous memory block. + bool is_channels_last_contiguous_ : 1; + + // Tensor is stored in the channels last 3d memory format, when dimensions + // order is (N)CDHW and C-strides < W-strides < H-strides < D - strides (< + // N-strides) (If size of any dimension is equal to 1, this dimension strides + // value is not taken into account). + bool is_channels_last_3d_ : 1; + + // Channels last 3d contiguous tensor is channel last 3d tensor which occupies + // contiguous memory block. + bool is_channels_last_3d_contiguous_ : 1; + + // Dense tensor is the tensor that store values in a contiguous block of + // memory. Non-overlapping tensor is the tensor in which elements occupy + // individual non-repetitive memory. + bool is_non_overlapping_and_dense_ : 1; + + bool is_wrapped_number_ : 1; + + // NOTE [ Metadata Change for a Detached Tensor ] + // + // Normally, a user is allowed to change the tensor metadata + // (e.g. sizes / strides / storage / storage_offset) of a tensor. + // However, if the tensor is created by `t1_detached = t1.data` in Python + // or `t1_detached = t1.detach()` in Python/C++, those changes to the + // tensor metadata of `t1_detached` will not be propagated back to the + // original tensor `t1`. In order to make such changes explicitly illegal, + // we created the `allow_tensor_metadata_change_` flag, to prevent users + // from changing metadata of the detached tensor and expecting the original + // tensor to also be updated. + // + // NOTE: For a full list of tensor metadata fields, please see + // `copy_tensor_metadata()` in TensorImpl and its subclasses to find + // which fields are copied by value. + bool allow_tensor_metadata_change_ : 1; + + // we decide to keep reserved_ and it will + // live in Tensor after the split + // The logic is that if Extend() or ReserveSpace() were ever called, + // then subsequent Resize()s will not free up Storage. + bool reserved_ : 1; + + // Call _custom() virtual methods for + // strides()/is_contiguous()/sizes()/dim()/numel() + // This is a combination of sizes_strides_custom_dispatch_ + // and has_symbolic_sizes_strides_ + uint8_t sizes_strides_policy_ : 2; + + // Whether or not sizes_and_strides_ contains a symbolic value. + bool has_symbolic_sizes_strides_ : 1; + + // Call _custom() virtual method for + // strides()/is_contiguous()/sizes()/dim()/numel() + uint8_t custom_sizes_strides_ : 2; + + // Combo of custom_ and python_custom_ + bool device_policy_ : 1; + bool layout_policy_ : 1; + + // Call _custom() virtual method for device() + bool custom_device_ : 1; + + // Call _custom() virtual method for layout() + bool custom_layout_ : 1; + + // Call into Python for + // strides()/is_contiguous()/sizes()/dim()/numel() + uint8_t python_custom_sizes_strides_ : 2; + + // Call into Python for device() + bool python_custom_device_ : 1; + + // Call into Python for layout() + bool python_custom_layout_ : 1; + + // The set of DispatchKeys which describe this tensor. NB: this + // does NOT include Autograd (historically, it did, but + // not anymore!) + // + // INVARIANT: extra_meta_->named_tensor_meta_ != nullptr <==> + // key_set_.has(DispatchKey::Named) + DispatchKeySet key_set_; + + private: + // C10_TensorImpl_Size_Check_Dummy_Class needs to be friends with + // TensorImpl so it can inspect the size of private fields + template < + size_t cplusplus, + size_t clang_ver_major, + size_t gcc_ver, + size_t gcc_ver_minor, + size_t nvcc, + size_t cuda_version, + size_t cuda_version_major, + size_t ptr_size> + friend class C10_TensorImpl_Size_Check_Dummy_Class; +}; + +namespace detail { + +#ifndef C10_MOBILE +template +struct TargetTraits< + T, + std::enable_if_t>>> { + static constexpr bool can_have_pyobject = true; +}; +#endif + +} // namespace detail + +// Note [TensorImpl size constraints] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Changed the size of TensorImpl? If the size went down, good for +// you! Adjust the documentation below and the expected size. +// Did it go up? Read on... +// +// Struct size matters. In some production systems at Facebook, we have +// 400M live tensors during a training run. Do the math: every 64-bit +// word you add to Tensor is an extra 3.2 gigabytes in RAM. +// +// If you are a Facebook employee, you can check if the run in question +// has tipped you over the point using the command here: +// https://fburl.com/q5enpv98 +// +// For reference, we OOMed at 160 bytes (20 words) per TensorImpl. +// This is not counting overhead from strides out-of-line allocation and +// StorageImpl space and this is from before we inlined sizes and strides +// directly into TensorImpl as SmallVectors. +// +// Our memory usage on 32-bit systems is suboptimal, but we're not checking +// for it at the moment (to help avoid rage inducing cycles when the +// 32-bit number is wrong). +// +// Current breakdown: +// +// vtable pointer +// strong refcount TODO: pack these into one word +// weak refcount +// storage pointer +// autograd metadata pointer +// named tensor metadata pointer +// version counter pointer +// PyObjectSlot +// SizesAndStrides size/pointer +// SizesAndStrides sizes (pre-allocated 0) +// SizesAndStrides sizes (pre-allocated 1) +// SizesAndStrides sizes (pre-allocated 2) +// SizesAndStrides sizes (pre-allocated 3) +// SizesAndStrides sizes (pre-allocated 4) +// SizesAndStrides strides (pre-allocated 0) +// SizesAndStrides strides (pre-allocated 1) +// SizesAndStrides strides (pre-allocated 2) +// SizesAndStrides strides (pre-allocated 3) +// SizesAndStrides strides (pre-allocated 4) +// storage offset +// numel +// data type, device, is_contiguous, storage_access_should_throw_, bitfields +// DispatchKeySet +// + +// Various preprocessor macros we use to check that the +// TensorImpl size hasn't changed unexpectedly. We undef +// these later. +#ifndef __NVCC__ +#define C10_NVCC 0 +#else +#define C10_NVCC __NVCC__ +#endif + +#ifndef __CUDA_VER_MAJOR__ +#define C10_CUDA_VERSION_MAJOR 0 +#else +#define C10_CUDA_VERSION_MAJOR __CUDA_VER_MAJOR__ +#endif + +#ifndef CUDA_VERSION +#define C10_CUDA_VERSION 0 +#else +#define C10_CUDA_VERSION CUDA_VERSION +#endif + +#ifndef __clang_major__ +#define C10_CLANG_MAJOR_VERSION 0 +#else +#define C10_CLANG_MAJOR_VERSION __clang_major__ +#endif + +#ifndef __GNUC__ +#define C10_GCC_VERSION 0 +#else +#define C10_GCC_VERSION __GNUC__ +#endif + +#ifndef __GNUC_MINOR__ +#define C10_GCC_VERSION_MINOR 0 +#else +#define C10_GCC_VERSION_MINOR __GNUC_MINOR__ +#endif + +// We use a templatized class to both contain the logic of checking the sizes +// as well as to provide compile-time information that might be useful in +// figuring out why sizes may have changed. +// All the compile time information is given by the template fields that are +// always printed by the compiler when the static_assert fails. +template < + size_t cplusplus = __cplusplus, + size_t clang_ver_major = C10_CLANG_MAJOR_VERSION, + size_t gcc_ver = C10_GCC_VERSION, + size_t gcc_ver_minor = C10_GCC_VERSION_MINOR, + size_t nvcc = C10_NVCC, + size_t cuda_version = C10_CUDA_VERSION, + size_t cuda_version_major = C10_CUDA_VERSION_MAJOR, + size_t ptr_size = sizeof(void*)> +class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { + // Names of (non-bitfield) fields in TensorImpl; used to provide + // compile-time info about fields whose size changes unexpectedly. + enum class FieldNameEnum { + storage_, + autograd_meta_, + extra_meta_, + version_counter_, + pyobj_slot_, + sizes_and_strides_, + storage_offset_, + numel_, + data_type_, + device_opt_, + key_set_, + TOTAL_SIZE + }; + + // Provides compile-time equality check that reveals what numbers + // were used and on which quantity + template + constexpr static bool are_equal() { + static_assert( + Actual == Expected, + "Actual and Expected sizes of a field did not match!"); + return true; + } + + // Provides compile-time <= check that reveals what numbers + // were used and on which quantity + template + constexpr static bool is_le() { + static_assert( + Actual <= Expected, + "Actual and Expected sizes of a field did not match!"); + return true; + } + + public: + // Compile-time check that TensorImpl field sizes are as expected + // + // Observed total sizes and associated versions + // If you find a flag that predicts when unique_ptr has 16 bytes + // on 64-bit systems or when sizes_and_strides_ is 84 vs 88 bytes + // on 32-bit systems you get a cookie! + // Length | LLVM | GCC | C++ | CUDA + // 192 | ? | 11.2 | 201703 | 11040 + // 208 | ? | 11.2 | 201703 | 11040 + // 208 | ? | 11.2 | 201402 | 11040 + // 192 | ? | 11.2 | 201402 | 11040 + // 160 | 12 | 4.2 | 201703 | 0 + // + // To keep things clean, we split on systems here. + +#if UINTPTR_MAX == 0xFFFFFFFF + // This is a 32-bit system + static constexpr bool check_sizes() { + constexpr size_t tsize = 20 * sizeof(int64_t); + + // clang-format off + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + is_le(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + is_le(); + // clang-format on + + return true; + } +#else + // This is a 64-bit system + static constexpr bool check_sizes() { + constexpr size_t tsize = 26 * sizeof(int64_t); + + // clang-format off + are_equal(); + // On some systems involving NVCC the size of unique_ptr is 16 bytes. We haven't + // figured out how to detect those via macro preprocessors yet, so we use <= + // comparisons for the relevant fields. + is_le(); + is_le(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + is_le(); + // clang-format on + + return true; + } +#endif +}; + +// We use a class to encapsulate size-checking logic with +// templates to capture sizes and flags. We call this within +// a static assert to prove there is no run-time behaviour. +// Since the methods we call return either true or fail their +// own static_asserts, we should never see the error messages +// below. We have to provide it though for c++ <17. +static_assert( + C10_TensorImpl_Size_Check_Dummy_Class<>::check_sizes(), + "You should not see this message."); + +// Clean up after ourselves +#undef C10_NVCC +#undef C10_CUDA_VERSION_MAJOR +#undef C10_CUDA_VERSION +#undef C10_CLANG_MAJOR_VERSION +#undef C10_GCC_VERSION +#undef C10_GCC_VERSION_MINOR + +} // namespace c10 + +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/TensorOptions.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/TensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..7add8edc4361ab3c38675d8565ad13b4d1ed48b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/TensorOptions.h @@ -0,0 +1,791 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + +namespace c10 { + +inline ScalarType dtype_or_default(std::optional dtype) { + return dtype.value_or(get_default_dtype_as_scalartype()); +} + +inline caffe2::TypeMeta dtype_or_default( + std::optional dtype) { + return dtype.value_or(get_default_dtype()); +} + +inline Layout layout_or_default(std::optional layout) { + return layout.value_or(kStrided); +} + +inline Device device_or_default(std::optional device) { + return device.value_or(Device(kCPU)); +} + +inline bool pinned_memory_or_default(std::optional pinned_memory) { + return pinned_memory.value_or(false); +} + +/// A class to encapsulate construction axes of an Tensor. TensorOptions was +/// designed to support the Python style API for specifying construction options +/// on factory functions, e.g., +/// +/// torch.zeros(2, 3, dtype=torch.int32) +/// +/// Because C++ doesn't natively support keyword arguments, there must be +/// another way of specifying keyword-like arguments. TensorOptions is a +/// builder class which can be used to construct this "dictionary" of keyword +/// arguments: functions which support TensorOptions conventionally take this +/// argument optionally as their last argument. +/// +/// WARNING: In PyTorch, there are `torch::` variants of factory functions, +/// e.g., torch::zeros for at::zeros. These return Variables (while the +/// stock ATen functions return plain Tensors). If you mix these functions +/// up, you WILL BE SAD. +/// +/// Rather than use the constructor of this class directly, you should prefer to +/// use the constructor functions, and then chain setter methods on top of them. +/// +/// at::device(at::kCUDA).dtype(kInt) +/// at::dtype(at::kInt) +/// +/// Additionally, anywhere a TensorOptions is expected, you can directly +/// pass at::kCUDA / at::kInt, and it will implicitly convert to a +/// TensorOptions. +/// +/// Here are some recommended ways to create a 2x2 tensor of zeros +/// with certain properties. These all *implicitly* make use of +/// TensorOptions, even if they don't mention the class explicitly: +/// +/// at::zeros({2,2}, at::kCUDA); +/// at::zeros({2,2}, at::kLong); +/// at::zeros({2,2}, at::device(at::kCUDA).dtype(at::kLong())); +/// at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1 +/// at::zeros({2,2}, at::requires_grad()); +/// + +/// NOTE [ TensorOptions Constructors ] +/// +/// TensorOptions is like a dictionary with entries from the set: +/// {requires_grad, device, dtype, layout}, where each entry may be +/// unspecified (i.e., is optional). It is used to specify the properties of +/// tensors in many places both in C++ internal and API, e.g., tensor factory +/// methods like `at::empty({10}, options)`, tensor conversions like +/// `tensor.to(...)`, etc. +/// +/// To provide a simple API that is consistent with Python, where one can do +/// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a +/// `torch.layout`, we want TensorOptions to be implicitly convertible from +/// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have +/// three implicit constructors from each of these three types. +/// +/// This is sufficient for `ScalarType` and `Layout` as they are simple Enum +/// classes. However, `Device` is an ordinary class with implicit constructors +/// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)` to be +/// consistent with Python API, where strings are treated as equivalent with a +/// `torch.device` object (e.g., "cuda:1" can be passed to everywhere a +/// `torch.device("cuda:1")` is accepted). To support the syntax +/// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure +/// that `TensorOptions` is implicitly constructible with any arguments that a +/// `Device` can constructed from. So we have, +/// +/// /* implicit */ TensorOptions(T&& device) : TensorOptions() { +/// this->set_device(device); +/// } +/// +/// template ::value>> +/// /* implicit */ TensorOptions(Args&&... args) +/// : TensorOptions(Device(std::forward(args)...)) {} +/// +/// +/// But this will be problematic. Consider this: `TensorOptions({kCUDA, 1})`. +/// Compiler will complain about ambiguity between the copy constructor and the +/// `Device` constructor because `{kCUDA, 1}` can be converted to both a +/// `TensorOption` and a `Device`. +/// +/// To get around this, we templatize the `Device` constructor. Since overload +/// resolution is done before template resolution, our problem is solved. + +DispatchKey computeDispatchKey( + std::optional dtype, + std::optional layout, + std::optional device); + +struct C10_API TensorOptions { + TensorOptions() + : requires_grad_(false), + pinned_memory_(false), + has_device_(false), + has_dtype_(false), + has_layout_(false), + has_requires_grad_(false), + has_pinned_memory_(false), + has_memory_format_(false) {} + + /// Constructs a `TensorOptions` object with the given layout. + /* implicit */ TensorOptions(Layout layout) : TensorOptions() { + this->set_layout(layout); + } + + /// Constructs a `TensorOptions` object with the given device. + /// See NOTE [ TensorOptions Constructors ] on why this is templatized. + template < + typename T, + typename = std::enable_if_t, Device>>> + /* implicit */ TensorOptions(T&& device) : TensorOptions() { + this->set_device(std::forward(device)); + } + + /// Constructs a `TensorOptions` object from arguments allowed in `Device` + /// constructors. + /// + /// See NOTE [ TensorOptions Constructors ]. + /// + /// NB: Ideally we only allow implicit constructors here. But there is no easy + /// way to detect them. So we have this one that allows explicit + /// constructors too. + template < + typename... Args, + typename = std::enable_if_t>> + /* implicit */ TensorOptions(Args&&... args) + : TensorOptions(Device(std::forward(args)...)) {} + + /// Constructs a `TensorOptions` object with the given dtype. + /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() { + this->set_dtype(dtype); + } + + /// legacy constructor to support ScalarType + /* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() { + this->set_dtype(dtype); + } + + /// Constructs a `TensorOptions` object with the given memory format. + /* implicit */ TensorOptions(MemoryFormat memory_format) : TensorOptions() { + set_memory_format(memory_format); + } + + /// Return a copy of `TensorOptions` with `device` set to the given one, or + /// cleared if `device` is `nullopt`. + [[nodiscard]] TensorOptions device( + std::optional device) const noexcept { + TensorOptions r = *this; + r.set_device(device); + return r; + } + + /// Return a copy of `TensorOptions` with `device` set to the given one. + /// (This overload ensures that variadic template std::optional constructor + /// for Device work correctly.) + template + [[nodiscard]] TensorOptions device(Args&&... args) const noexcept { + return device( + std::optional(std::in_place, std::forward(args)...)); + } + + /// Return a copy of `TensorOptions`, but with device set to CUDA, and the + /// device index set to the given one. + /// + /// TODO: This function encourages bad behavior (assuming CUDA is + /// the only device that matters). Get rid of it / rename it. + [[nodiscard]] TensorOptions device_index( + c10::DeviceIndex device_index) const noexcept { + return device(Device::Type::CUDA, device_index); + } + + /// Return a copy of `TensorOptions` with `dtype` set to the given one. + [[nodiscard]] TensorOptions dtype( + std::optional dtype) const noexcept { + TensorOptions r = *this; + r.set_dtype(dtype); + return r; + } + + // legacy function to support ScalarType + [[nodiscard]] TensorOptions dtype( + std::optional dtype) const noexcept { + TensorOptions r = *this; + r.set_dtype(dtype); + return r; + } + + // Since dtype is taken... + template + TensorOptions& dtype() { + dtype_ = caffe2::TypeMeta::Make(); + has_dtype_ = true; + return *this; + } + + /// Sets the layout of the `TensorOptions`. + [[nodiscard]] TensorOptions layout( + std::optional layout) const noexcept { + TensorOptions r = *this; + r.set_layout(layout); + return r; + } + + /// Sets the `requires_grad` property of the `TensorOptions`. + [[nodiscard]] TensorOptions requires_grad( + std::optional requires_grad) const noexcept { + TensorOptions r = *this; + r.set_requires_grad(requires_grad); + return r; + } + + /// Sets the `pinned_memory` property on the `TensorOptions`. + [[nodiscard]] TensorOptions pinned_memory( + std::optional pinned_memory) const noexcept { + TensorOptions r = *this; + r.set_pinned_memory(pinned_memory); + return r; + } + + /// Sets the `memory_format` property on `TensorOptions`. + [[nodiscard]] TensorOptions memory_format( + std::optional memory_format) const noexcept { + TensorOptions r = *this; + r.set_memory_format(memory_format); + return r; + } + + /// Returns the device of the `TensorOptions`. + Device device() const noexcept { + return device_or_default(device_opt()); + } + + /// Returns whether the device is specified. + bool has_device() const noexcept { + return has_device_; + } + + /// Returns the device of the `TensorOptions`, or `std::nullopt` if + /// device is not specified. + std::optional device_opt() const noexcept { + return has_device_ ? std::make_optional(device_) : std::nullopt; + } + + /// Returns the device index of the `TensorOptions`. + c10::DeviceIndex device_index() const noexcept { + return device().index(); + } + + /// Returns the dtype of the `TensorOptions`. + caffe2::TypeMeta dtype() const noexcept { + return dtype_or_default(dtype_opt()); + } + + /// Returns whether the dtype is specified. + bool has_dtype() const noexcept { + return has_dtype_; + } + + /// Returns the dtype of the `TensorOptions`, or `std::nullopt` if + /// device is not specified. + std::optional dtype_opt() const noexcept { + return has_dtype_ ? std::make_optional(dtype_) : std::nullopt; + } + + /// Returns the layout of the `TensorOptions`. + Layout layout() const noexcept { + return layout_or_default(layout_opt()); + } + + /// Returns whether the layout is specified. + bool has_layout() const noexcept { + return has_layout_; + } + + /// Returns the layout of the `TensorOptions`, or `std::nullopt` if + /// layout is not specified. + std::optional layout_opt() const noexcept { + return has_layout_ ? std::make_optional(layout_) : std::nullopt; + } + + /// Returns the `requires_grad` property of the `TensorOptions`. + bool requires_grad() const noexcept { + return has_requires_grad_ ? requires_grad_ : false; + } + + /// Returns whether the `requires_grad` is specified. + bool has_requires_grad() const noexcept { + return has_requires_grad_; + } + + /// Returns the `requires_grad` property of the `TensorOptions`, or + /// `std::nullopt` if `requires_grad` is not specified. + std::optional requires_grad_opt() const noexcept { + return has_requires_grad_ ? std::make_optional(requires_grad_) + : std::nullopt; + } + + /// Returns the `pinned_memory` property of the `TensorOptions`. + bool pinned_memory() const noexcept { + return pinned_memory_or_default(pinned_memory_opt()); + } + + /// Returns whether the `pinned_memory` is specified. + bool has_pinned_memory() const noexcept { + return has_pinned_memory_; + } + + /// Returns if the layout is sparse + bool is_sparse() const { + return layout_ == c10::Layout::Sparse; + } + + /// Returns if the layout is sparse CSR, deprecated, use + /// is_sparse_compressed() instead + bool is_sparse_csr() const { + return layout_ == c10::Layout::SparseCsr; + } + + bool is_sparse_compressed() const { + return layout_ == c10::Layout::SparseCsr || + layout_ == c10::Layout::SparseCsc || + layout_ == c10::Layout::SparseBsr || layout_ == c10::Layout::SparseBsc; + } + + // For compatibility with legacy tensor.type() comparisons + bool type_equal(const TensorOptions& other) const { + return computeDispatchKey() == other.computeDispatchKey() && + typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype()); + } + + /// Returns the `pinned_memory` property of the `TensorOptions`, or + /// `std::nullopt` if `pinned_memory` is not specified. + std::optional pinned_memory_opt() const noexcept { + return has_pinned_memory_ ? std::make_optional(pinned_memory_) + : std::nullopt; + } + + /// Returns whether the `memory_layout` is specified + bool has_memory_format() const noexcept { + return has_memory_format_; + } + + // NB: memory_format() getter is PURPOSELY not defined, as the default + // behavior of memory_format varies from function to function. + + /// Returns the `memory_layout` property of `TensorOptions, or + /// `std::nullopt` if `memory_format` is not specified. + std::optional memory_format_opt() const noexcept { + return has_memory_format_ ? std::make_optional(memory_format_) + : std::nullopt; + } + + // Resolves the ATen backend specified by the current construction axes. + // TODO: Deprecate this + Backend backend() const { + return at::dispatchKeyToBackend(computeDispatchKey()); + } + + /// Return the right-biased merge of two TensorOptions. This has the + /// effect of overwriting settings from self with specified options + /// of options. + /// + /// NB: This merging operation does NOT respect device merges. + /// For example, if you device({kCUDA, 1}).merge_in(kCUDA) + /// you will get kCUDA in the end! Functions like Tensor.new_empty + /// ensure the right device is selected anyway by way of a + /// device guard. + /// + TensorOptions merge_in(TensorOptions options) const noexcept { + TensorOptions merged = *this; + if (options.has_device()) + merged.set_device(options.device_opt()); + if (options.has_dtype()) + merged.set_dtype(options.dtype_opt()); + if (options.has_layout()) + merged.set_layout(options.layout_opt()); + // NB: requires grad is right biased; not a logical AND/OR! + if (options.has_requires_grad()) + merged.set_requires_grad(options.requires_grad_opt()); + if (options.has_pinned_memory()) + merged.set_pinned_memory(options.pinned_memory_opt()); + if (options.has_memory_format()) + merged.set_memory_format(options.memory_format_opt()); + return merged; + } + + // TODO remove after TensorOptions rationalization + TensorOptions merge_memory_format( + std::optional optional_memory_format) const noexcept { + TensorOptions merged = *this; + if (optional_memory_format.has_value()) { + merged.set_memory_format(optional_memory_format); + } + return merged; + } + + // INVARIANT: computeDispatchKey returns only the subset of dispatch keys for + // which dispatchKeyToBackend is injective, if it is defined at all (for + // the most part, this just means that this function never returns an + // Autograd key) + DispatchKey computeDispatchKey() const { + return c10::computeDispatchKey( + optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt()); + } + + private: + // These methods are currently private because I'm not sure if it's wise + // to actually publish them. They are methods because I need them in + // the constructor and the functional API implementation. + // + // If you really, really need it, you can make these public, but check if you + // couldn't just do what you need with the functional API. Similarly, these + // methods are not chainable, because if you wanted chaining, you probably + // want to use the functional API instead. (It's probably OK to make + // these chainable, because these functions are all explicitly annotated + // with a ref-qualifier, the trailing &, that makes them illegal to call + // on temporaries.) + + /// Mutably set the device of `TensorOptions`. + void set_device(std::optional device) & noexcept { + if (device) { + device_ = *device; + has_device_ = true; + } else { + has_device_ = false; + } + } + + /// Mutably set the dtype of `TensorOptions`. + void set_dtype(std::optional dtype) & noexcept { + if (dtype) { + dtype_ = *dtype; + has_dtype_ = true; + } else { + has_dtype_ = false; + } + } + + // legacy function to support ScalarType + void set_dtype(std::optional dtype) & noexcept { + if (dtype) { + dtype_ = scalarTypeToTypeMeta(*dtype); + has_dtype_ = true; + } else { + has_dtype_ = false; + } + } + + /// Mutably set the layout of `TensorOptions`. + void set_layout(std::optional layout) & noexcept { + if (layout) { + layout_ = *layout; + has_layout_ = true; + } else { + has_layout_ = false; + } + } + + /// Mutably set the `requires_grad` property of `TensorOptions`. + void set_requires_grad(std::optional requires_grad) & noexcept { + if (requires_grad) { + requires_grad_ = *requires_grad; + has_requires_grad_ = true; + } else { + has_requires_grad_ = false; + } + } + + /// Mutably set the `pinned_memory` property of `TensorOptions`. + void set_pinned_memory(std::optional pinned_memory) & noexcept { + if (pinned_memory) { + pinned_memory_ = *pinned_memory; + has_pinned_memory_ = true; + } else { + has_pinned_memory_ = false; + } + } + + /// Mutably set the `memory_Format` property of `TensorOptions`. + void set_memory_format(std::optional memory_format) & noexcept { + if (memory_format) { + memory_format_ = *memory_format; + has_memory_format_ = true; + } else { + has_memory_format_ = false; + } + } + + // WARNING: If you edit TensorOptions to add more options, you + // may need to adjust the implementation of Tensor::options. + // The criteria for whether or not Tensor::options must be adjusted + // is whether or not the new option you added should preserved + // by functions such as empty_like(); if it should be preserved, + // you must adjust options(). + // + // TODO: MemoryFormat is not implemented in this way + + // NB: We didn't use std::optional here, because then we can't pack + // the has_***_ boolean fields. + + Device device_ = at::kCPU; // 16-bit + caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 16-bit + Layout layout_ = at::kStrided; // 8-bit + MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit + + // Bitmask required here to get this to fit inside 32 bits (or even 64 bits, + // for that matter) + + bool requires_grad_ : 1; + bool pinned_memory_ : 1; + + bool has_device_ : 1; + bool has_dtype_ : 1; + bool has_layout_ : 1; + bool has_requires_grad_ : 1; + bool has_pinned_memory_ : 1; + bool has_memory_format_ : 1; +}; + +// We should aspire to fit in one machine-size word; but a size greater than two +// words is too much. (We are doing terribly on 32-bit archs, where we require +// three machine size words to store tensor options. Eek!) +static_assert( + sizeof(TensorOptions) <= sizeof(int64_t) * 2, + "TensorOptions must fit in 128-bits"); + +/// Convenience function that returns a `TensorOptions` object with the `dtype` +/// set to the given one. +inline TensorOptions dtype(caffe2::TypeMeta dtype) { + return TensorOptions().dtype(dtype); +} + +// legacy function to support ScalarType +inline TensorOptions dtype(ScalarType dtype) { + return TensorOptions().dtype(scalarTypeToTypeMeta(dtype)); +} + +/// Convenience function that returns a `TensorOptions` object with the `layout` +/// set to the given one. +inline TensorOptions layout(Layout layout) { + return TensorOptions().layout(layout); +} + +/// Convenience function that returns a `TensorOptions` object with the `device` +/// set to the given one. +inline TensorOptions device(Device device) { + return TensorOptions().device(device); +} + +/// Convenience function that returns a `TensorOptions` object with the +/// `device` set to CUDA and the `device_index` set to the given one. +inline TensorOptions device_index(c10::DeviceIndex device_index) { + return TensorOptions().device_index(device_index); +} + +/// Convenience function that returns a `TensorOptions` object with the +/// `requires_grad` set to the given one. +inline TensorOptions requires_grad(bool requires_grad = true) { + return TensorOptions().requires_grad(requires_grad); +} + +/// Convenience function that returns a `TensorOptions` object with the +/// `memory_format` set to the given one. +inline TensorOptions memory_format(MemoryFormat memory_format) { + return TensorOptions().memory_format(memory_format); +} + +C10_API std::ostream& operator<<( + std::ostream& stream, + const TensorOptions& options); + +template +inline TensorOptions dtype() { + return dtype(caffe2::TypeMeta::Make()); +} + +inline std::string toString(const TensorOptions& options) { + std::ostringstream stream; + stream << options; + return stream.str(); +} + +// This is intended to be a centralized location by which we can determine +// what an appropriate DispatchKey for a tensor is. +inline DispatchKey computeDispatchKey( + std::optional dtype, + std::optional layout, + std::optional device) { + const auto layout_ = layout_or_default(layout); + const auto device_ = device_or_default(device); + switch (layout_) { + case Layout::Jagged: + case Layout::Strided: { + const auto dtype_ = dtype_or_default(dtype); + switch (device_.type()) { +#define DO_CASE(device, _) \ + case c10::DeviceType::device: { \ + if (isQIntType(dtype_)) { \ + return DispatchKey::Quantized##device; \ + } \ + return DispatchKey::device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + case c10::DeviceType::FPGA: + return DispatchKey::FPGA; + case c10::DeviceType::MAIA: + return DispatchKey::MAIA; + case c10::DeviceType::Vulkan: + return DispatchKey::Vulkan; + case c10::DeviceType::Metal: + return DispatchKey::Metal; + case c10::DeviceType::MKLDNN: + case c10::DeviceType::OPENGL: + case c10::DeviceType::OPENCL: + case c10::DeviceType::IDEEP: + TORCH_INTERNAL_ASSERT( + 0, + "This is a grandfathered Caffe2 device type ", + device_.type(), + ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error."); + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for dense layout: ", + device_.type()); + } + } + case Layout::Sparse: + switch (device_.type()) { +#define DO_CASE(device, _) \ + case c10::DeviceType::device: { \ + return DispatchKey::Sparse##device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for sparse layout: ", + device_.type()); + } + case Layout::Mkldnn: + switch (device_.type()) { + case c10::DeviceType::CPU: + return DispatchKey::MkldnnCPU; + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for mkldnn layout: ", + device_.type()); + } + case Layout::SparseCsr: + case Layout::SparseCsc: + case Layout::SparseBsr: + case Layout::SparseBsc: + switch (device_.type()) { +#define DO_CASE(device, _) \ + case c10::DeviceType::device: { \ + return DispatchKey::SparseCsr##device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for ", + layout_, + " layout: ", + device_.type()); + } + default: + TORCH_CHECK(false, "Unsupported layout: ", layout_); + } +} + +inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) { + switch (dispatch_key) { +#define DO_CASE(bc, _) case DispatchKey::Sparse##bc: + C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused) +#undef DO_CASE + return Layout::Sparse; +#define DO_CASE(bc, _) case DispatchKey::SparseCsr##bc: + C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused) +#undef DO_CASE + TORCH_CHECK( + false, "Cannot map DispatchKey ", dispatch_key, " to a unique layout."); + case DispatchKey::MkldnnCPU: + return Layout::Mkldnn; + default: + return Layout::Strided; + } +} + +inline c10::DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { + switch (dispatch_key) { + // stuff that's real +#define DO_CASE(suffix, prefix) \ + case DispatchKey::prefix##suffix: \ + return c10::DeviceType::suffix; +#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix) + C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES) +#undef DO_CASES +#undef DO_CASE + + case DispatchKey::MkldnnCPU: + return c10::DeviceType::CPU; + case DispatchKey::Vulkan: + return c10::DeviceType::Vulkan; + + case DispatchKey::MAIA: + return c10::DeviceType::MAIA; + default: + TORCH_CHECK( + false, + "DispatchKey ", + dispatch_key, + " doesn't correspond to a device"); + } +} + +inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) { + return TensorOptions() + .layout(dispatchKeyToLayout(dispatch_key)) + .device(dispatchKeyToDeviceType(dispatch_key)); +} + +namespace detail { +inline bool backend_supports_empty_operator(const TensorOptions& options) { + // Quantized backends don't support at::empty(). + // They have separate operators like at::empty_quantized() that take in + // extra information about how to quantize the tensor. + return !isQIntType(typeMetaToScalarType(options.dtype())); +} + +} // namespace detail + +} // namespace c10 + +C10_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/UndefinedTensorImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/UndefinedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..3a8381e887f90556b66f8b654bb5376e16afe074 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/UndefinedTensorImpl.h @@ -0,0 +1,54 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +struct C10_API UndefinedTensorImpl final : public TensorImpl { + public: + // Without this, we get: + // error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in + // device code + // (ostensibly because the constexpr tricks MSVC into trying to compile this + // function for device as well). +#ifdef _WIN32 + static inline TensorImpl* singleton() { + return &getInstance(); + } +#else + static constexpr inline TensorImpl* singleton() { + return &_singleton; + } +#endif + +#ifdef DEBUG + bool has_storage() const override; +#endif + void set_storage_offset(int64_t offset) override; + + protected: + c10::SymBool sym_is_contiguous_custom(MemoryFormat format) const override; + IntArrayRef strides_custom() const override; + SymIntArrayRef sym_strides_custom() const override; + + private: + UndefinedTensorImpl(); +#ifdef _WIN32 + static UndefinedTensorImpl& getInstance(); +#else + static UndefinedTensorImpl _singleton; +#endif + const char* tensorimpl_type_name() const override; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/WrapDimMinimal.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/WrapDimMinimal.h new file mode 100644 index 0000000000000000000000000000000000000000..02570ae84ffdb64c1b2c8b20deb52178c606f57d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/WrapDimMinimal.h @@ -0,0 +1,53 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +namespace detail { +// This template can only be specialized at int64_t and c10::SymInt; +// you'll get linker errors otherwise +template +C10_API T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); +} // namespace detail + +template +T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { + // Inline the fast paths + if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { + // For SymInts, we want an explicit control flow to trigger a guard, so we + // may as well branch too. + if (dim < 0) { + return dim + dim_post_expr; + } + return dim; + } + // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) + return c10::detail::maybe_wrap_dim_slow( + std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +inline int64_t maybe_wrap_dim( + int64_t dim, + int64_t dim_post_expr, + bool wrap_scalar = true) { + return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); +} + +inline c10::SymInt maybe_wrap_dim( + c10::SymInt dim, + c10::SymInt dim_post_expr, + bool wrap_scalar = true) { + return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/alignment.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/alignment.h new file mode 100644 index 0000000000000000000000000000000000000000..4ef01f7bfa99c473ebb6612a83f0cdde53eeec6b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/alignment.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +#ifdef C10_MOBILE +// Use 16-byte alignment on mobile +// - ARM NEON AArch32 and AArch64 +// - x86[-64] < AVX +constexpr size_t gAlignment = 16; +#else +// Use 64-byte alignment should be enough for computation up to AVX512. +constexpr size_t gAlignment = 64; +#endif + +constexpr size_t gPagesize = 4096; +// since the default thp pagesize is 2MB, enable thp only +// for buffers of size 2MB or larger to avoid memory bloating +constexpr size_t gAlloc_threshold_thp = static_cast(2) * 1024 * 1024; + +// Cache line size used to avoid false sharing between threads. Falls back to 64 +// bytes if C++17 feature is unavailable. +#ifdef __cpp_lib_hardware_interference_size +using std::hardware_destructive_interference_size; +#else +constexpr std::size_t hardware_destructive_interference_size = 64; +#endif +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/COW.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/COW.h new file mode 100644 index 0000000000000000000000000000000000000000..1ef394e6e3536530af4a6427f16f0a383c39c5be --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/COW.h @@ -0,0 +1,37 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { +struct StorageImpl; +class DataPtr; +} // namespace c10 + +namespace c10::impl::cow { + +// Creates a Copy-on-write (COW) clone of the given storage. This will also +// convert the given storage into a COW storage if it is not COW already. +// +// Converting the storage into a COW storage will not be successful if the +// storage's DataPtr has some context (`DataPtr::get_context()`) which is not +// equal to the data pointer (`DataPtr::get()`). In this case, a nullptr is +// returned. +C10_API c10::intrusive_ptr lazy_clone_storage( + StorageImpl& storage); + +// Check if a storage has a simple DataPtr with no abnormal context +C10_API bool has_simple_data_ptr(const c10::StorageImpl& storage); + +// Check if a DataPtr is COW +C10_API bool is_cow_data_ptr(const c10::DataPtr& data_ptr); + +// Eagerly copies a COW storage's data, turning it into a non-COW storage. +C10_API void materialize_cow_storage(StorageImpl& storage); + +} // namespace c10::impl::cow + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/COWDeleter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/COWDeleter.h new file mode 100644 index 0000000000000000000000000000000000000000..90a618003c995ce6fe949b8f0ea5110a8a47b74a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/COWDeleter.h @@ -0,0 +1,71 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace c10::impl::cow { + +// A COWDeleterContext object is used as the `ctx` argument for DataPtr +// to implement a Copy-on-write (COW) DataPtr. +class C10_API COWDeleterContext { + public: + // Creates an instance, holding the pair of data and original + // deleter. + // + // Note that the deleter will only be called in our destructor if + // the last reference to this goes away without getting + // materialized. + explicit COWDeleterContext(std::unique_ptr data); + + // Increments the current refcount. + void increment_refcount(); + + // See README.md in this directory to understand the locking + // strategy. + + // Represents a reference to the context. + // + // This is returned by decrement_refcount to allow the caller to + // copy the data under the shared lock. + using NotLastReference = std::shared_lock; + + // Represents the last reference to the context. + // + // This will be returned by decrement_refcount when it is the last + // reference remaining and after any pending copies have completed. + using LastReference = std::unique_ptr; + + // Decrements the refcount, returning a handle indicating what to + // do with it. + std::variant decrement_refcount(); + + private: + // The destructor is hidden, this should only ever be used within + // UniqueVoidPtr using cow::delete_context as the deleter. + ~COWDeleterContext(); + + std::shared_mutex mutex_; + std::unique_ptr data_; + std::atomic refcount_ = 1; +}; + +// `cow_deleter` is used as the `ctx_deleter` for DataPtr to implement a COW +// DataPtr. +// +// Warning: This should only be called on a pointer to a COWDeleterContext that +// was allocated on the heap with `new`, because when the refcount reaches 0, +// the context is deleted with `delete`. +C10_API void cow_deleter(void* ctx); + +} // namespace c10::impl::cow + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/DeviceGuardImplInterface.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/DeviceGuardImplInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..f8b12a993a2a82c4b09b74e5c26ca48bcff3f4bf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/DeviceGuardImplInterface.h @@ -0,0 +1,417 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +// Just for C10_ANONYMOUS_VARIABLE +#include +#include + +#include +#include + +namespace c10 { + +// Forward declaration +class DataPtr; + +/** + * Note [Flags defining the behavior of events] + * + * PYTORCH_DEFAULT and BACKEND_DEFAULT are valid for all backends. The + * BACKEND_DEFAULT is what a particular backend would select if no + * flags were given. PYTORCH_DEFAULT is the PyTorch's framework default + * choice for events on that backend, which may not be the same. + * + * The mapping of PYTORCH_DEFAULT and BACKEND_DEFAULT is done by each + * backend implementation. + */ +enum class EventFlag { + // Disable timing + PYTORCH_DEFAULT, + // Enable timing + BACKEND_DEFAULT, + // FOR TESTING ONLY + INVALID +}; + +namespace impl { + +/** + * DeviceGuardImplInterface represents the virtual interface which provides + * functionality to provide an RAII class for device and stream switching, + * via DeviceGuard. Every distinct device type, e.g., CUDA and HIP, is + * expected to implement and register an implementation of this interface. + * All classes which inherit from DeviceGuardImplInterface should be declared + * 'final'. + * + * This class exists because we provide a unified interface for performing + * device guards via DeviceGuard, but we cannot assume that we have actually + * compiled against the, e.g., CUDA library, which actually implements + * this guard functionality. In this case, a dynamic dispatch is required + * to cross the library boundary. + * + * If possible, you should directly use implementations of this interface; + * those uses will be devirtualized. + */ +struct C10_API DeviceGuardImplInterface { + DeviceGuardImplInterface() = default; + DeviceGuardImplInterface(const DeviceGuardImplInterface&) = default; + DeviceGuardImplInterface& operator=(const DeviceGuardImplInterface&) = + default; + DeviceGuardImplInterface(DeviceGuardImplInterface&&) noexcept = default; + DeviceGuardImplInterface& operator=(DeviceGuardImplInterface&&) noexcept = + default; + + /** + * Return the type of device managed by this guard implementation. + */ + virtual DeviceType type() const = 0; + + /** + * Set the current device to Device, and return the previous Device. + */ + virtual Device exchangeDevice(Device) const = 0; + // NB: Implementations of exchangeDevice can be a bit boilerplatey. You might + // consider replacing exchangeDevice with a non-virtual function with a baked + // in implementation; however, note that this will triple the number of + // virtual calls (when you implement exchangeDevice in a final subclass, + // the compiler gets to devirtualize everything; it won't do that if you don't + // define it in the subclass!) A common way to solve this problem is to use + // some sort of CRTP; however, we can template DeviceGuardImplInterface since + // we really *do* need it to be virtual. A little boilerplate seems easiest + // to explain. (Another way around this problem is to provide inline + // functions that provide the default implementations, but this seems a little + // hard to explain. In any case, we're only going to have on order of ten + // implementations of this anyway.) + + /** + * Get the current device. + */ + virtual Device getDevice() const = 0; + + /** + * Set the current device to Device. + */ + virtual void setDevice(Device) const = 0; + + /** + * Set the current device to Device, without checking for errors + * (so, e.g., this can be called from a destructor). + */ + virtual void uncheckedSetDevice(Device) const noexcept = 0; + + /** + * Get the current stream for a given device. + */ + virtual Stream getStream(Device) const = 0; + + /** + * Get the default stream for a given device. + */ + virtual Stream getDefaultStream(Device /*unused*/) const { + TORCH_CHECK(false, "Backend doesn't support acquiring a default stream.") + } + + /** + * Get a stream from the global pool for a given device. + */ + virtual Stream getStreamFromGlobalPool( + Device /*unused*/, + bool isHighPriority = false) const { + (void)isHighPriority; // Suppress unused variable warning + TORCH_CHECK(false, "Backend doesn't support acquiring a stream from pool.") + } + + /** + * Return a new stream for a given device and priority. The stream will be + * copied and shared around, device backend should be able to correctly handle + * the lifetime of the stream. + */ + virtual Stream getNewStream(Device /*unused*/, int priority = 0) const { + (void)priority; + TORCH_CHECK(false, "Backend doesn't support create a new Stream.") + } + + /** + * Set a stream to be the thread local current stream for its device. + * Return the previous stream for that device. You are NOT required + * to set the current device to match the device of this stream. + */ + virtual Stream exchangeStream(Stream) const = 0; + + /** + * Destroys the given event. + */ + virtual void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) + const noexcept {} + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + virtual void record( + void** /*event*/, + const Stream& /*stream*/, + const DeviceIndex /*device_index*/, + const c10::EventFlag /*flag*/) const { + TORCH_CHECK(false, "Backend doesn't support events."); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + virtual void block(void* /*event*/, const Stream& /*stream*/) const { + TORCH_CHECK(false, "Backend doesn't support events."); + } + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + virtual bool queryEvent(void* /*event*/) const { + TORCH_CHECK(false, "Backend doesn't support events."); + } + + /** + * Get the number of devices. WARNING: This is REQUIRED to not raise + * an exception. If there is some sort of problem, e.g., driver error, + * you should report that there are zero available devices. + */ + virtual DeviceIndex deviceCount() const noexcept = 0; + + /** + * Get the following capabilities of the current device: + * (1) Data type support + * Returns DeviceCapability object. + */ + virtual DeviceCapability getDeviceCapability(Device /*unused*/) const { + TORCH_CHECK(false, "Backend doesn't support getting device capabilities."); + } + + /** + * Return true if all the work previously enqueued on the stream for + * asynchronous execution has completed running on the device. + */ + virtual bool queryStream(const Stream& /*stream*/) const { + TORCH_CHECK(false, "Backend doesn't support querying streams."); + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the stream has completed running on the device. + */ + virtual void synchronizeStream(const Stream& /*stream*/) const { + TORCH_CHECK(false, "Backend doesn't support synchronizing streams."); + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * recorded on the event has completed running on the device. + */ + virtual void synchronizeEvent(void* /*event*/) const { + TORCH_CHECK(false, "Backend doesn't support synchronizing events."); + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the device has been completed. + */ + virtual void synchronizeDevice(const DeviceIndex /*device_index*/) const { + TORCH_CHECK( + false, "Backend doesn't support synchronizing all streams on device."); + } + + /** + * Ensure the caching allocator (if any) is aware that the given DataPtr is + * being used on the given stream, and that it should thus avoid recycling the + * DataPtr until all work on that stream is done. + */ + virtual void recordDataPtrOnStream( + const c10::DataPtr& /*unused*/, + const Stream& /*unused*/) const {} + + /** + * Fetch the elapsed time between two recorded events. + */ + virtual double elapsedTime( + void* /*event1*/, + void* /*event2*/, + const DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Backend doesn't support elapsedTime."); + } + + /** + * Intended use of this class is to leak the DeviceGuardImpl at program end. + * So you better not call the destructor, buster! + */ + virtual ~DeviceGuardImplInterface() = default; +}; + +// A no-op device guard impl that doesn't do anything interesting. Useful +// for devices that don't actually have a concept of device index. Prominent +// examples are CPU and Meta. +template +struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface { + NoOpDeviceGuardImpl() = default; + DeviceType type() const override { + return D; + } + Device exchangeDevice(Device /*unused*/) const override { + return Device(D, -1); // no-op + } + Device getDevice() const override { + return Device(D, -1); + } + void setDevice(Device /*unused*/) const override { + // no-op + } + void uncheckedSetDevice(Device /*unused*/) const noexcept override { + // no-op + } + Stream getStream(Device /*unused*/) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(D, -1)); + } + + Stream getNewStream(Device /*unused*/, int priority = 0) const override { + // no-op + (void)priority; + return Stream(Stream::DEFAULT, Device(D, -1)); + } + + // NB: These do NOT set the current device + Stream exchangeStream(Stream /*unused*/) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(D, -1)); + } + DeviceIndex deviceCount() const noexcept override { + return 1; + } + + DeviceCapability getDeviceCapability(Device /*unused*/) const override { + DeviceCapability cap; + if constexpr (D == DeviceType::Meta) { + cap.capability_data.capability_bits = 0; + // Meta only supports basic types for shape inference + // Byte, Char, Short, Int, Long, Float, Double, + // Bool, ComplexFloat, ComplexDouble + cap.capability_data.capability_bits = (1ULL << kIndex_Byte) | + (1ULL << kIndex_Char) | (1ULL << kIndex_Short) | + (1ULL << kIndex_Int) | (1ULL << kIndex_Long) | + (1ULL << kIndex_Float) | (1ULL << kIndex_Double) | + (1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_ComplexDouble) | + (1ULL << kIndex_Bool); + } + return cap; + } + + // Event-related functions + void record( + void** /*event*/, + const Stream& /*stream*/, + const DeviceIndex /*device_index*/, + const EventFlag /*flag*/) const override { + TORCH_CHECK(false, D, " backend doesn't support events."); + } + void block(void* /*event*/, const Stream& /*stream*/) const override { + TORCH_CHECK(false, D, " backend doesn't support events.") + } + bool queryEvent(void* /*event*/) const override { + TORCH_CHECK(false, D, " backend doesn't support events.") + } + void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) + const noexcept override {} + + // Stream-related functions + bool queryStream(const Stream& /*stream*/) const override { + return true; + } + void synchronizeStream(const Stream& /*stream*/) const override { + // Don't wait for anything. + } +}; + +// The registry is NON-owning. Each stored pointer is std::atomic so +// that under all interleavings of registry calls the structure is +// race-free. This doesn't cost us anything on reads in X86. (An +// unsynchronized implementation probably is OK too, but I didn't want +// to prove that we never read from device_guard_impl_registry at the +// same time some registration is occurring. Shiver.) +// +// I'd like this registry to be valid even at program destruction time +// (in case someone uses a DeviceGuard in a destructor to do some cleanup +// in the CUDA API.) Since there are no direct accesses of the underlying +// owning objects which I can use to enforce initialization order (unlike +// in a Meyer singleton), it implies that you must *leak* objects when +// putting them in the registry. This is done by deleting the destructor +// on DeviceGuardImplInterface. +extern C10_API std::array< + std::atomic, + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)> + device_guard_impl_registry; + +// I can't conveniently use c10/util/Registry.h for the following reason: +// c10/util/Registry.h gives me a slow way of Create'ing a object of some +// interface from the registry, but no way of quickly accessing an already +// created object. I'll be banging on getDeviceGuardImpl every time we do a +// DeviceGuard, so I really don't want to be doing an unordered_map lookup. +// Better if the registration mechanism directly drops its implementation +// into device_guard_impl_registry. + +class C10_API DeviceGuardImplRegistrar { + public: + DeviceGuardImplRegistrar( + DeviceType /*type*/, + const DeviceGuardImplInterface* /*impl*/); +}; + +#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \ + static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \ + g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl()); + +inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) { + // Two adjacent int16_t fields DeviceType and DeviceIndex has field access + // miscompiled on NVCC. To workaround this issue, we apply a mask to the + // DeviceType. First check if the DeviceType is 16-bit. + // FB employees can see + // https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/ + // for more details + static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit"); + auto p = device_guard_impl_registry[static_cast(type) & 0xFF].load(); + + // This seems to be the first place where you make use of a device + // when you pass devices to factory functions. Give a nicer error + // message in this case. + TORCH_CHECK(p, "PyTorch is not linked with support for ", type, " devices"); + return p; +} + +void C10_API +registerDeviceGuard(DeviceType type, const DeviceGuardImplInterface* impl); + +inline bool hasDeviceGuardImpl(DeviceType type) { + return device_guard_impl_registry[static_cast(type)].load(); +} + +void C10_API ensureCUDADeviceGuardSet(); + +} // namespace impl +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/FakeGuardImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/FakeGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..902a4d3febafc5d9ea5c5695c428d25be7c171c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/FakeGuardImpl.h @@ -0,0 +1,107 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include + +namespace c10::impl { + +// FakeGuardImpl is hardcoded to have eight devices. Not for +// any good reason, just to simplify code. +constexpr DeviceIndex kFakeGuardImplMaxDevices = 8; + +/** + * A fake implementation of DeviceGuardImplInterface suitable for testing. + * The current device is modeled as a mutable field in the guard implementation + * class. See DeviceGuard_test.cpp for an example use. + */ +template +struct FakeGuardImpl final : public DeviceGuardImplInterface { + static constexpr DeviceType static_type = T; + // Runtime device type is not used + FakeGuardImpl(DeviceType /*unused*/) {} + FakeGuardImpl() = default; + DeviceType type() const override { + return T; + } + Device exchangeDevice(Device d) const override { + AT_ASSERT(d.type() == type()); + AT_ASSERT(d.index() < kFakeGuardImplMaxDevices); + Device old_device = getDevice(); + if (old_device.index() != d.index()) { + current_device_ = d.index(); + } + return old_device; + } + Device getDevice() const override { + return Device(type(), current_device_); + } + void setDevice(Device d) const override { + AT_ASSERT(d.type() == type()); + AT_ASSERT(d.index() >= 0); + AT_ASSERT(d.index() < kFakeGuardImplMaxDevices); + current_device_ = d.index(); + } + void uncheckedSetDevice(Device d) const noexcept override { + current_device_ = d.index(); + } + Stream getStream(Device d) const noexcept override { + return Stream(Stream::UNSAFE, d, current_streams_[d.index()]); + } + Stream exchangeStream(Stream s) const noexcept override { + auto old_id = current_streams_[s.device_index()]; + current_streams_[s.device_index()] = s.id(); + return Stream(Stream::UNSAFE, s.device(), old_id); + } + DeviceIndex deviceCount() const noexcept override { + return kFakeGuardImplMaxDevices; + } + + // Event-related functions + void record( + void** /*event*/, + const Stream& /*stream*/, + const DeviceIndex /*device_index*/, + const EventFlag /*flag*/) const override {} + void block(void* /*event*/, const Stream& /*stream*/) const override {} + bool queryEvent(void* /*event*/) const override { + return true; + } + void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) + const noexcept override {} + + // Convenience methods for testing + static DeviceIndex getDeviceIndex() { + return current_device_; + } + static void setDeviceIndex(DeviceIndex i) { + AT_ASSERT(i >= 0); + AT_ASSERT(i < kFakeGuardImplMaxDevices); + current_device_ = i; + } + static StreamId getCurrentStreamIdFor(DeviceIndex i) { + return current_streams_.at(i); + } + static void resetStreams() { + current_streams_.fill(0); + } + + private: + thread_local static DeviceIndex current_device_; + thread_local static std::array + current_streams_; +}; + +template +thread_local DeviceIndex FakeGuardImpl::current_device_ = 0; + +template +thread_local std::array + FakeGuardImpl::current_streams_ = {0, 0, 0, 0, 0, 0, 0, 0}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/GPUTrace.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/GPUTrace.h new file mode 100644 index 0000000000000000000000000000000000000000..57761cff9bc254158816d43451ed5bc01f60411f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/GPUTrace.h @@ -0,0 +1,33 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10::impl { + +struct C10_API GPUTrace { + // On the x86 architecture the atomic operations are lock-less. + static std::atomic gpuTraceState; + + // When PyTorch migrates to C++20, this should be changed to an atomic flag. + // Currently, the access to this variable is not synchronized, on the basis + // that it will only be flipped once and by the first interpreter that + // accesses it. + static bool haveState; + + // This function will only register the first interpreter that tries to invoke + // it. For all of the next ones it will be a no-op. + static void set_trace(const PyInterpreter* /*trace*/); + + static const PyInterpreter* get_trace() { + if (!haveState) + return nullptr; + return gpuTraceState.load(std::memory_order_acquire); + } +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/HermeticPyObjectTLS.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/HermeticPyObjectTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..032b90a20bd297b742711ada1d9d5ed1501a5e7e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/HermeticPyObjectTLS.h @@ -0,0 +1,67 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::impl { + +// This TLS controls whether or not we permanently associate PyObject +// with Tensor the first time it is allocated. When hermetic PyObject +// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor, +// meaning you get a distinct PyObject whenever you execute the code in +// question. +struct C10_API HermeticPyObjectTLS { + static void set_state(bool state); + static bool get_state() { + // Hypothetical fastpath if torchdeploy/multipy // codespell:ignore multipy + // isn't used. Per + // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf + // this qualifies relaxed access because it is a single-location data + // structure (only the boolean here). + // + // Forgetting about data races for a moment, is there a logical race? + // + // - Boolean only ever transitions from false to true. So the + // critical situation is when one interpreter is already running + // when a second interpreter switches haveState from false to true. + // + // - The first interpreter is indifferent whether or not it sees + // hasState true/false; obviously false works (this is what the + // interpreter was previously using; more directly, the interpreter + // calls into itself as the handler, so being hermetic is not + // required), and true simply means serviced python operator calls will + // be hermetic; in these cases it is expected to be functionally + // equivalent. + // + // - The second interpreter MUST see hasState true (as its requests will + // be forwarded to the first interpreter), but it is assumed that there + // is a synchronization between the interpreter initialization, and + // when we actually perform operations, so it is guaranteed to see + // hasState true. + // + // QED. + // + // This fastpath is currently disabled so that we can more easily test that + // hermetic mode works correctly even on stock build of PyTorch. + if (false && !haveState_.load(std::memory_order_relaxed)) + return false; + return get_tls_state(); + } + // Call this from the multipy/torchdeploy // codespell:ignore multipy + // top level + static void init_state(); + + private: + // This only flipped once from false to true during + // torchdeploy/multipy initialization, // codespell:ignore multipy + // and never again. + static std::atomic haveState_; + static bool get_tls_state(); +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineDeviceGuard.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineDeviceGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..34d6dff97654888cd12d52ce1f44441f30247e44 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineDeviceGuard.h @@ -0,0 +1,438 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// This file provides implementations of InlineDeviceGuard and +// InlineOptionalDeviceGuard. + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10::impl { + +/** + * A DeviceGuard is an RAII class that sets a device to some value + * on construction, and resets the device to its original value on + * destruction. + * + * InlineDeviceGuard is a helper class for implementing DeviceGuards. + * It is templated over a DeviceGuardImpl (anything that implements + * DeviceGuardImplInterface). There are two primary ways to instantiate + * InlineDeviceGuard: + * + * - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl. + * This is the best way to use InlineDeviceGuard, as all calls are + * devirtualized, giving you code as efficient as straight line + * calls to cudaGetDevice/cudaSetDevice. + * + * - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl + * retrieved from a DeviceType registry. We have explicitly instantiated + * InlineDeviceGuard this way as c10::DeviceGuard. + * + * If you are in a hurry, you can use InlineDeviceGuard directly: + * + * using CUDAGuard = impl::InlineDeviceGuard; + * + * However, you can provide a better user experience if you explicitly write a + * wrapper class that itself contains the template instantiation: + * + * class CUDAGuard { + * public: + * // ... the API ... + * private: + * impl::InlineDeviceGuard guard_; + * } + * + * The wrapper class provides a good place to write documentation, and helps + * avoid weird template instantiation errors when a user incorrectly uses the + * class. + * + * If you need to test this class, consider instantiating it with FakeGuardImpl. + */ +template +class InlineDeviceGuard { + public: + // Note [Omitted default constructor from RAII] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // In principle, we could add a default constructor to + // DeviceGuard which reads the current device and promises to + // restore to that device on exit. However, most cases where you + // would have written this, you probably meant to actually just + // use DeviceGuard (since you don't actually need the + // restore to happen if you don't ever actually set the device). + // We remove the constructor here to encourage you to think about + // what you actually want to happen. + explicit InlineDeviceGuard() = delete; + + /// Set the current device to the passed Device. + explicit InlineDeviceGuard(Device device) + : impl_(device.type()), + original_device_( + device.index() == -1 ? impl_.getDevice() + : impl_.exchangeDevice(device)), + current_device_(device.index() == -1 ? original_device_ : device) {} + + /// Set the current device index to the passed DeviceIndex. (The + /// device type is inferred from the template parameter T). + template < + typename U = T, + typename = + typename std::enable_if_t>> + explicit InlineDeviceGuard(DeviceIndex device_index) + : InlineDeviceGuard(Device(U::static_type, device_index)) {} + + /// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit + /// DeviceGuardImplInterface pointer. + template < + typename U = T, + typename = typename std::enable_if_t>> + explicit InlineDeviceGuard( + Device device, + const DeviceGuardImplInterface* impl) + : impl_( + VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))), + original_device_( + device.index() == -1 ? impl_.getDevice() + : impl_.exchangeDevice(device)), + current_device_(device.index() == -1 ? original_device_ : device) {} + + /// Copy is disallowed + InlineDeviceGuard(const InlineDeviceGuard&) = delete; + InlineDeviceGuard& operator=(const InlineDeviceGuard&) = delete; + + /// Move is disallowed, as DeviceGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + InlineDeviceGuard(InlineDeviceGuard&& other) = delete; + InlineDeviceGuard& operator=(InlineDeviceGuard&& other) = delete; + + ~InlineDeviceGuard() { + impl_.uncheckedSetDevice(original_device_); + } + + /// Sets the device to the given one. + template < + typename U = T, + typename std::enable_if_t, int> = 0> + void set_device(at::Device device) { + AT_ASSERT( + (U::static_type == DeviceType::HIP && device.is_cuda()) || + device.type() == U::static_type); + auto index = device.index(); + if (index == -1) + return; + impl_.setDevice(device); + current_device_ = device; + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device. This is effectively equivalent to + /// set_device when a guard supports only a single device type. + template + typename std::enable_if_t> reset_device( + at::Device device) { + set_device(device); + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device (for a possibly different device + /// type). + /// + /// This method is named reset_device to highlight the fact that previous + /// device settings from this guard are NOT preserved, even if the device + /// has a different device type. For example: + /// + /// // CUDA device is 0 + /// DeviceGuard g(Device(kCUDA, 1)); + /// g.reset_device(Device(kHIP, 2)); + /// // CUDA device is 0 (!!) + /// + /// NOTE: this implementation may skip some device setting if it can prove + /// that it is unnecessary. + /// + /// Optional argument is for testing only. + template + typename std::enable_if_t> reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl = nullptr) { + auto index = device.index(); + if (index == -1) + return; + if (device.type() == original_device_.type()) { + AT_ASSERT(impl == nullptr || impl->type() == device.type()); + impl_.setDevice(device); + current_device_ = device; + } else { + // Destruct and reconstruct the DeviceGuard in place + impl_.setDevice(original_device_); + impl_ = !impl ? VirtualGuardImpl(device.type()) : VirtualGuardImpl(impl); + original_device_ = impl_.exchangeDevice(device); + current_device_ = device; + } + } + + /// Sets the device index to the given one. The device type is inferred + /// from the original device type. + void set_index(DeviceIndex index) { + reset_device(Device(original_device_.type(), index)); + } + + /// Returns the device that was set at the time the most recent + /// reset_device(), or otherwise the device at construction time. + Device original_device() const { + return original_device_; + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return current_device_; + } + + protected: + T impl_; + + private: + Device original_device_; + Device current_device_; +}; + +/** + * A OptionalDeviceGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * + * InlineOptionalDeviceGuard is a helper class for implementing + * OptionalDeviceGuards. See guidance in InlineDeviceGuard on how to + * use this. See OptionalDeviceGuard for user-oriented usage notes. + */ +template +class InlineOptionalDeviceGuard { + public: + // Note [Explicit initialization of optional fields] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Explicit initialization of optional fields + // required to workaround an nvcc bug; see + // https://github.com/pytorch/pytorch/issues/12117 + + /// Creates an uninitialized OptionalDeviceGuard. + explicit InlineOptionalDeviceGuard() + : guard_() // See Note [Explicit initialization of optional fields] + {} + ~InlineOptionalDeviceGuard() = default; + + /// Set the current device to the passed Device, if it is not nullopt. + explicit InlineOptionalDeviceGuard(std::optional device_opt) + : guard_() { // See Note [Explicit initialization of optional fields] + if (device_opt.has_value()) { + guard_.emplace(device_opt.value()); + } + } + + /// Set the current device to the passed DeviceIndex, if it is not nullopt. + template < + typename U = T, + typename = + typename std::enable_if_t>> + explicit InlineOptionalDeviceGuard( + std::optional device_index_opt) + : guard_() { // See Note [Explicit initialization of optional fields] + if (device_index_opt.has_value()) { + guard_.emplace(device_index_opt.value()); + } + } + + /// All constructors of DeviceGuard are valid for OptionalDeviceGuard + /// and result in initialized OptionalDeviceGuard. + template + explicit InlineOptionalDeviceGuard(Args&&... args) + : guard_(std::in_place, std::forward(args)...) {} + + // TODO: Consider reading Tensor and TensorList constructors here, when + // Tensor moves to c10. (These are only valid on OptionalDeviceGuard, + // because a Tensor may be undefined, in which case we need an uninitialized + // tensor guard.) + + // Note [Move construction for RAII guards is tricky] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // In principle, move construction is useful for terminating + // the lifetime of a `OptionalDeviceGuard` early; for example: + // + // // current device is d0 + // OptionalDeviceGuard g1(d1); + // // current device is d1 + // { + // OptionalDeviceGuard g2(std::move(g1)); + // } + // // current device is d0!! + // + // However, it's difficult to implement the move constructor + // in a way that works in all situations. For example, consider + // the following example: + // + // OptionalDeviceGuard g1(d1); + // { + // OptionalDeviceGuard g2(d2); + // { + // OptionalDeviceGuard g3(std::move(g1)); // !!! + // } + // } + // + // What should the current device be while g3 in scope... and what + // should it be after it goes out of scope? What about g2? + // There don't seem to be satisfactory answers for these questions. + // + // It's in principle possible to raise an error when this occurs + // by doing some extra thread-local bookkeeping. But why bother? + // Just don't provide the constructor. + InlineOptionalDeviceGuard(const InlineOptionalDeviceGuard& other) = delete; + InlineOptionalDeviceGuard(InlineOptionalDeviceGuard&& other) = delete; + + // Note [Move assignment for RAII guards is tricky] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Move assignment is deleted, because you need to know which guard was + // defined "first", as that guard's original_device_ wins--with the current + // representation, we have no way of telling which is the case. (Move + // construction does not have this problem, as one guard is always + // uninitialized.) + // + // We can make this clear by way of a pair of examples: + // + // Example 1: + // + // // initial device is n0 + // { + // CUDAGuard g1(n1); + // { + // CUDAGuard g2(n2); + // // current device should be n2 + // g1 = std::move(g2); + // // current device should still be n2 + // } + // // current device should still be n2 + // } + // // current device should be n0 + // + // Example 2 (flip the order of the two guards): + // + // // initial device is n0 + // { + // CUDAGuard g2(n2); + // { + // CUDAGuard g1(n1); + // // current device should be n1 + // g1 = std::move(g2); + // // current device should be n2 + // } + // // current device should be n0 (since g2 has been vacated) + // } + // + // In both examples, we need g1 to restore to n0 after move assignment. + // However, in example 1, this is determined by the restore value of g1 + // (prior to the move). In example 2, however, it is determined by the the + // restore value of g2(!!). We don't know which one should win, without having + // a way of telling which guard was allocated first. + // + // We could solve this with an extra thread-local variable. But no one is + // actually using move-assignment. So just get rid of it. + InlineOptionalDeviceGuard& operator=(const InlineOptionalDeviceGuard& other) = + delete; + InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = + delete; + + /// Sets the device to the given one. Initializes OptionalDeviceGuard if it + /// is not already initialized. + template < + typename U = T, + typename = + typename std::enable_if_t>> + void set_device(at::Device device) { + if (!guard_.has_value()) { + guard_.emplace(device); + } else { + guard_->set_device(device); + } + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device (for a possibly different device + /// type). Initializes OptionalDeviceGuard if it is not already initialized. + /// + /// See notes on why this is called reset_device on InlineDeviceGuard. + /// + /// Optional argument is for testing only. + template < + typename U = T, + typename = typename std::enable_if_t>> + void reset_device( + at::Device device, + const DeviceGuardImplInterface* impl = nullptr) { + if (!guard_.has_value()) { + guard_.emplace(device, impl); + } else { + guard_->reset_device(device, impl); + } + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device. Initializes the guard if it is + /// not already initialized. This is effectively equivalent to set_device + /// when a guard supports only a single device type. + template < + typename U = T, + typename = + typename std::enable_if_t>> + void reset_device(at::Device device) { + if (!guard_.has_value()) { + guard_.emplace(device); + } else { + guard_->reset_device(device); + } + } + + /// Sets the device index to the given one. The device type is statically + /// known. + template < + typename U = T, + typename = + typename std::enable_if_t>> + void set_index(DeviceIndex index) { + if (!guard_.has_value()) { + guard_.emplace(index); + } else { + guard_->set_index(index); + } + } + + /// Returns the device that was set immediately prior to initialization of + /// the, guard, or nullopt if the guard is uninitialized. + std::optional original_device() const { + return guard_.has_value() ? std::make_optional(guard_->original_device()) + : std::nullopt; + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device, if the guard is initialized, + /// or nullopt if the guard is uninitialized. + std::optional current_device() const { + return guard_.has_value() ? std::make_optional(guard_->current_device()) + : std::nullopt; + } + + /// Restore the original device, resetting this guard to uninitialized state. + void reset() { + guard_.reset(); + } + + private: + std::optional> guard_; +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineEvent.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..15d4083daab7439295a132ca3b157eae1ba6745d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineEvent.h @@ -0,0 +1,152 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace c10::impl { + +template +struct InlineEvent final { + InlineEvent() = delete; + InlineEvent( + const DeviceType _device_type, + const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) + : backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {} + + // Copy constructor and copy assignment operator (deleted) + InlineEvent(const InlineEvent&) = delete; + InlineEvent& operator=(const InlineEvent&) = delete; + + // Move constructor and move assignment operator + InlineEvent(InlineEvent&& other) noexcept + : event_(other.event_), + backend_(std::move(other.backend_)), + device_type_(other.device_type_), + device_index_(other.device_index_), + flag_(other.flag_), + was_marked_for_recording_(other.was_marked_for_recording_) { + other.event_ = nullptr; + } + InlineEvent& operator=(InlineEvent&& other) noexcept { + swap(other); + return *this; + } + + void swap(InlineEvent& other) noexcept { + std::swap(event_, other.event_); + std::swap(backend_, other.backend_); + std::swap(device_type_, other.device_type_); + std::swap(device_index_, other.device_index_); + std::swap(flag_, other.flag_); + std::swap(was_marked_for_recording_, other.was_marked_for_recording_); + } + + ~InlineEvent() noexcept { + if (event_) + backend_.destroyEvent(event_, device_index_); + } + + DeviceType device_type() const noexcept { + return device_type_; + } + DeviceIndex device_index() const noexcept { + return device_index_; + } + EventFlag flag() const noexcept { + return flag_; + } + bool was_marked_for_recording() const noexcept { + return was_marked_for_recording_; + } + + void recordOnce(const Stream& stream) { + if (!was_marked_for_recording_) + record(stream); + } + + void record(const Stream& stream) { + TORCH_CHECK( + stream.device_type() == device_type_, + "Event device type ", + DeviceTypeName(device_type_), + " does not match recording stream's device type ", + DeviceTypeName(stream.device_type()), + "."); + + backend_.record(&event_, stream, device_index_, flag_); + was_marked_for_recording_ = true; + device_index_ = stream.device_index(); + } + + void block(const Stream& stream) const { + if (!was_marked_for_recording_) + return; + + TORCH_CHECK( + stream.device_type() == device_type_, + "Event device type ", + DeviceTypeName(device_type_), + " does not match blocking stream's device type ", + DeviceTypeName(stream.device_type()), + "."); + + backend_.block(event_, stream); + } + + bool query() const { + if (!was_marked_for_recording_) + return true; + return backend_.queryEvent(event_); + } + + void* eventId() const { + return event_; + } + + double elapsedTime(const InlineEvent& other) const { + TORCH_CHECK( + other.device_type() == device_type_, + "Event device type ", + DeviceTypeName(device_type_), + " does not match other's device type ", + DeviceTypeName(other.device_type()), + "."); + TORCH_CHECK_VALUE( + (flag_ == EventFlag::BACKEND_DEFAULT) && + (other.flag_ == EventFlag::BACKEND_DEFAULT), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + was_marked_for_recording() && other.was_marked_for_recording(), + "Both events must be recorded before calculating elapsed time."); + // elapsedTime in MPS can wait event to be completed if event is not ready, + // which is a little different from CUDA + TORCH_CHECK( + (query() && other.query()) || device_type_ == DeviceType::MPS, + "Both events must be completed before calculating elapsed time."); + + return backend_.elapsedTime(event_, other.event_, device_index_); + } + + void synchronize() const { + if (!was_marked_for_recording_) + return; + backend_.synchronizeEvent(event_); + } + + private: + void* event_ = nullptr; + T backend_; + DeviceType device_type_; + DeviceIndex device_index_ = -1; + EventFlag flag_ = EventFlag::PYTORCH_DEFAULT; + bool was_marked_for_recording_ = false; +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineStreamGuard.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineStreamGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..7ce87a9a8eb55a30e8e6fb0ab6e5a38bc065dab9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/InlineStreamGuard.h @@ -0,0 +1,265 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10::impl { + +/** + * A StreamGuard is an RAII class that changes the current device + * to the device corresponding to some stream, and changes the + * default stream on that device to be this stream. + * + * InlineStreamGuard is a helper class for implementing StreamGuards. + * See InlineDeviceGuard for guidance on how to use this class. + */ +template +class InlineStreamGuard : private InlineDeviceGuard { + public: + /// No default constructor, see Note [Omitted default constructor from RAII] + explicit InlineStreamGuard() = delete; + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + explicit InlineStreamGuard(Stream stream) + : InlineDeviceGuard(stream.device()), + original_stream_of_original_device_( + this->impl_.getStream(original_device())), + original_stream_of_current_device_(this->impl_.exchangeStream(stream)), + current_stream_(stream) {} + + /// This constructor exists purely for testing + template < + typename U = T, + typename = typename std::enable_if_t>> + explicit InlineStreamGuard( + Stream stream, + const DeviceGuardImplInterface* impl) + : InlineDeviceGuard( + stream.device(), + impl ? impl : getDeviceGuardImpl(stream.device_type())), + original_stream_of_original_device_( + this->impl_.getStream(original_device())), + original_stream_of_current_device_(this->impl_.exchangeStream(stream)), + current_stream_(stream) {} + + /// Copy is disallowed + InlineStreamGuard(const InlineStreamGuard&) = delete; + InlineStreamGuard& operator=(const InlineStreamGuard&) = delete; + + /// Move is disallowed, as StreamGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + InlineStreamGuard(InlineStreamGuard&& other) = delete; + InlineStreamGuard& operator=(InlineStreamGuard&& other) = delete; + + ~InlineStreamGuard() { + this->impl_.exchangeStream(original_stream_of_current_device_); + } + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// + /// NOTE: this implementation may skip some stream/device setting if + /// it can prove that it is unnecessary. + /// + /// WARNING: reset_stream does NOT preserve previously set streams on + /// different devices. If you need to set streams on multiple devices + /// use MultiStreamGuard instead. + void reset_stream(Stream stream) { + // TODO: make a version that takes an impl argument. Unfortunately, + // that will require SFINAE because impl is only valid for the + // VirtualGuardImpl specialization. + if (stream.device() == this->current_device()) { + this->impl_.exchangeStream(stream); + current_stream_ = stream; + } else { + // Destruct and reconstruct the StreamGuard in-place + this->impl_.exchangeStream(original_stream_of_current_device_); + this->reset_device(stream.device()); + original_stream_of_current_device_ = this->impl_.exchangeStream(stream); + current_stream_ = stream; + } + } + + // It's not clear if set_device should also reset the current stream + // if the device is unchanged; therefore, we don't provide it. + // The situation is somewhat clearer with reset_device, but it's still + // a pretty weird thing to do, so haven't added this either. + + /// Returns the stream of the original device prior to this guard. Subtly, + /// the stream returned here is the original stream of the *original* + /// device; i.e., it's the stream that your computation *would* have + /// been put on, if it hadn't been for this meddling stream guard. + /// This is usually what you want. + Stream original_stream() const { + return original_stream_of_original_device_; + } + + /// Returns the most recent stream that was set using this device guard, + /// either from construction, or via set_stream. + Stream current_stream() const { + return current_stream_; + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return InlineDeviceGuard::current_device(); + } + + /// Returns the device that was set at the most recent reset_stream(), + /// or otherwise the device at construction time. + Device original_device() const { + return InlineDeviceGuard::original_device(); + } + + private: + Stream + original_stream_of_original_device_; // what the user probably cares about + Stream original_stream_of_current_device_; // what we need to restore + Stream current_stream_; +}; + +/** + * An OptionalStreamGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * See InlineOptionalDeviceGuard for more guidance on how to use this class. + */ +template +class InlineOptionalStreamGuard { + public: + /// Creates an uninitialized stream guard. + explicit InlineOptionalStreamGuard() + : guard_() // See Note [Explicit initialization of optional fields] + {} + ~InlineOptionalStreamGuard() = default; + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream, + /// if the passed stream is not nullopt. + explicit InlineOptionalStreamGuard(std::optional stream_opt) + : guard_() { + if (stream_opt.has_value()) { + guard_.emplace(stream_opt.value()); + } + } + + /// All constructors of StreamGuard are valid for OptionalStreamGuard + template + explicit InlineOptionalStreamGuard(Args&&... args) + : guard_(std::in_place, std::forward(args)...) {} + + InlineOptionalStreamGuard(const InlineOptionalStreamGuard& other) = delete; + InlineOptionalStreamGuard& operator=(const InlineOptionalStreamGuard& other) = + delete; + // See Note [Move construction for RAII guards is tricky] + InlineOptionalStreamGuard(InlineOptionalStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = + delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Initializes the OptionalStreamGuard if it was not previously initialized. + void reset_stream(Stream stream) { + if (guard_.has_value()) { + guard_->reset_stream(stream); + } else { + guard_.emplace(stream); + } + } + + /// Returns the stream that was set at the time the guard was most recently + /// initialized, or nullopt if the guard is uninitialized. + std::optional original_stream() const { + return guard_.has_value() ? std::make_optional(guard_->original_stream()) + : std::nullopt; + } + + /// Returns the most recent stream that was set using this stream guard, + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + std::optional current_stream() const { + return guard_.has_value() ? std::make_optional(guard_->current_stream()) + : std::nullopt; + } + + /// Restore the original device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } + + private: + std::optional> guard_; +}; + +template +class InlineMultiStreamGuard { + public: + /// Calls `set_stream` on each of the streams in the list. + /// This may be useful if you need to set different streams + /// for different devices. + explicit InlineMultiStreamGuard(ArrayRef streams) { + if (!streams.empty()) { + impl_.emplace(getDeviceTypeOfStreams(streams)); + original_streams_.reserve(streams.size()); + for (const Stream& s : streams) { + original_streams_.emplace_back(this->impl_->exchangeStream(s)); + } + } + } + + /// Copy is disallowed + InlineMultiStreamGuard(const InlineMultiStreamGuard&) = delete; + InlineMultiStreamGuard& operator=(const InlineMultiStreamGuard&) = delete; + + /// Move is disallowed, as StreamGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + InlineMultiStreamGuard(InlineMultiStreamGuard&& other) = delete; + InlineMultiStreamGuard& operator=(InlineMultiStreamGuard&& other) = delete; + + ~InlineMultiStreamGuard() noexcept { + if (this->impl_.has_value()) { + for (const Stream& s : original_streams_) { + this->impl_->exchangeStream(s); + } + } + } + + protected: + std::optional impl_; + + private: + /// The original streams that were active on all devices. + std::vector original_streams_; + + static DeviceType getDeviceTypeOfStreams(ArrayRef streams) { + TORCH_INTERNAL_ASSERT(!streams.empty()); + DeviceType type = streams[0].device_type(); + for (const auto idx : c10::irange(1, streams.size())) { + TORCH_CHECK_VALUE( + streams[idx].device_type() == type, + "Streams have a mix of device types: stream 0 is on ", + streams[0].device(), + " while stream ", + idx, + " is on device ", + streams[idx].device()); + } + return type; + } +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/LocalDispatchKeySet.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/LocalDispatchKeySet.h new file mode 100644 index 0000000000000000000000000000000000000000..123a288a0834468abc2e8bc7dc90b6e775506621 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/LocalDispatchKeySet.h @@ -0,0 +1,174 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +// TLS management for DispatchKeySet (the "local" DispatchKeySet(s)) +// +// This manages two thread-local DispatchKeySets: +// +// - The included type set, which adds a tensor type for consideration +// in dispatch. (For example, you might add Profiling to +// the included type set to turn on profiling on all tensor operations.) +// +// - The excluded type set, which disqualifies a tensor type from dispatch. +// (For example, after redispatching on variable, we disqualify +// Autograd so we don't attempt to handle variable again.) +// (Exclusion wins over inclusion.) +// +// NB: Originally, I implemented the excluded type set as storing the inverted +// set, but TLS is defined to be zero-initialized, so this doesn't actually work +// (if it's inverted, you want the set to be -1 initialized). + +namespace c10::impl { + +// POD version of LocalDispatchKeySet. Declared here just so that +// we can put it in the guards. +// This struct encapsulates special handling for TLS initialization +// in set_included()/included() API so that they reflect the truth. +// If you want to create PODLocalDispatchKeySet with non-zero state, +// use set_included() instead of default constructor. +struct C10_API PODLocalDispatchKeySet { + uint64_t included_; + uint64_t excluded_; + + // See Note [TLS Initialization] + DispatchKeySet included() const { + return DispatchKeySet(DispatchKeySet::RAW, included_) ^ + c10::default_included_set; + } + DispatchKeySet excluded() const { + return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ + c10::default_excluded_set; + } + + void set_included(DispatchKeySet x) { + included_ = (x ^ c10::default_included_set).raw_repr(); + } + void set_excluded(DispatchKeySet x) { + excluded_ = (x ^ c10::default_excluded_set).raw_repr(); + } +}; +static_assert( + std::is_trivial_v, + "PODLocalDispatchKeySet must be a POD type."); + +struct C10_API LocalDispatchKeySet { + /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x) + : included_(x.included()), excluded_(x.excluded()) {} + DispatchKeySet included_; + DispatchKeySet excluded_; +}; + +// thread_local variables cannot be C10_API on Windows. +// Inlining this seems to break AutoDispatchBelowAutograd on Android. +#if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) +C10_API LocalDispatchKeySet tls_local_dispatch_key_set(); +#else // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) +extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; + +inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() { + // Don't let people fiddle with the thread_local directly just + // because they include this header. + return raw_local_dispatch_key_set; +} +#endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) + +// Internal, use ThreadLocalStateGuard +C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set); + +// RAII API for manipulating the thread-local dispatch state. + +class C10_API IncludeDispatchKeyGuard { + public: + IncludeDispatchKeyGuard(DispatchKeySet /*include*/); + IncludeDispatchKeyGuard(DispatchKey k) + : IncludeDispatchKeyGuard(DispatchKeySet(k)) {} + IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete; + IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete; + IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete; + IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete; + ~IncludeDispatchKeyGuard(); + + private: + // A little micro-optimization to save us from tls_get_addr call + // on destruction + PODLocalDispatchKeySet* tls_; + DispatchKeySet include_; +}; + +class C10_API ExcludeDispatchKeyGuard { + public: + ExcludeDispatchKeyGuard(DispatchKeySet /*exclude*/); + ExcludeDispatchKeyGuard(DispatchKey k) + : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {} + ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete; + ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete; + ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete; + ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete; + ~ExcludeDispatchKeyGuard(); + + private: + // A little micro-optimization to save us from tls_get_addr call + // on destruction + PODLocalDispatchKeySet* tls_; + DispatchKeySet exclude_; +}; + +struct C10_API ForceDispatchKeyGuard { + public: + ForceDispatchKeyGuard() + : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {} + ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set) + : ForceDispatchKeyGuard() { + c10::impl::_force_tls_local_dispatch_key_set(key_set); + } + ForceDispatchKeyGuard( + c10::DispatchKeySet include, + c10::DispatchKeySet exclude) + : ForceDispatchKeyGuard() { + auto updated_set = saved_keyset_; + updated_set.included_ = include; + updated_set.excluded_ = exclude; + c10::impl::_force_tls_local_dispatch_key_set(updated_set); + } + + ForceDispatchKeyGuard(ForceDispatchKeyGuard&&) noexcept = delete; + ForceDispatchKeyGuard(const ForceDispatchKeyGuard&) = delete; + ForceDispatchKeyGuard& operator=(const ForceDispatchKeyGuard&) = delete; + ForceDispatchKeyGuard& operator=(ForceDispatchKeyGuard&&) = delete; + ~ForceDispatchKeyGuard() { + c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_); + } + + private: + c10::impl::LocalDispatchKeySet saved_keyset_; +}; + +// Non-RAII API for manipulating the thread-local dispatch state. +// Please prefer the RAII API. The non-RAII API may be useful when +// the included/excluded state of a given DispatchKey must span +// many calls from the Python to the C++, so you cannot conveniently +// use an RAII guard. +// +// Example use case: a Python context manager that includes a certain +// DispatchKey, to ensure ops running under the context manager dispatch +// through that DispatchKey's registered overrides. +// +// The non-RAII API is less efficient than the RAII guards because both the +// getter and setter will do a tls_getaddr lookup (the RAII struct only needs +// one!) + +C10_API bool tls_is_dispatch_key_excluded(DispatchKey x); +C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state); +C10_API bool tls_is_dispatch_key_included(DispatchKey x); +C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state); +C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks); +C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks); + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyInterpreter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyInterpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..ce74e9b9050b3db0db196ff4ef9f3cad198c9beb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyInterpreter.h @@ -0,0 +1,257 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declarations + +namespace c10 { +struct IValue; +class OperatorHandle; +struct TensorImpl; +namespace impl { +struct PyObjectSlot; +} // namespace impl +} // namespace c10 + +namespace torch::jit { +using Stack = std::vector; +} + +// Actual implementation + +namespace c10::impl { + +struct C10_API PyInterpreter; + +// Note [Python interpreter tag] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Traditionally, PyTorch is layered such that our Python library +// (libtorch_python) references our pure C++ library (libtorch) as the +// natural order of things. However, sometimes this natural order is +// subverted: C++ objects refer to Python objects (for example, we +// store a PyObject* pointer on TensorImpl so that converting from a +// C++ Tensor to a Python Tensor is just a memory dereference). +// +// These unusual orderings must be treated with care. To start, you need to +// virtualize the destructor so that the PyObject can be decref'ed on +// destruction (because the C++ object itself doesn't know anything about +// Python--remember, layering!). This process itself is fraught, since +// acquiring the GIL could lead to deadlocks if someone is blocking on you +// while holding the GIL. Furthermore, if the C++ objects outlive the +// interpreter (which can happen if you stash them in a static global +// variable defined in libtorch), you may attempt to decref the object when +// the Python interpreter has already been shutdown. +// +// BUT WAIT, IT GETS WORSE. With torchdeploy, there may be multiple Python +// interpreters in a single process. If a C++ object is accessible from +// multiple interpreters, we must take care not to accidentally pass a +// PyObject from one interpreter with another interpreter. +// +// To prevent these mixups, we introduce a PyInterpreter "tag" (object with +// a vtable), which specifies a specific Python interpreter. +// +// - Any given object can be associated with AT MOST one Python interpreter. +// We represent the interpreter tag as a memory address to an instance of +// a virtual class that is allocated once per interpreter (this is so that +// we can request the interpreter to perform operations for us, if +// necessary). +// +// - It can be recorded with a PyObject (PyInterpreterObject) so that +// we know what interpreter the object is associated with, and we can +// raise an error if you try to use the PyObject from the wrong +// interpreter context. +// +// - It contains a vtable that can be used to perform various Python +// operations from ordinary C++ code that ordinarily wouldn't be accessible +// from libtorch. +// +// A simple use case is when a C++ object must be associated with a PyObject. +// However, for TensorImpl, we lazily allocate a PyObject the first time the +// object passes into Python. The invariants for this situation are more +// subtle: +// +// - A given TensorImpl's interpreter tag can only go from uninitialized to +// tagged; once tagged, this is a quiescent state (once tagged to an +// interpreter, ALWAYS tagged to that interpreter) +// +// - A thread may mutate the PyObject field of a TensorImpl if and only if it +// holds the GIL for the interpreter tagged on the TensorImpl. (If the +// TensorImpl is not tagged, it must first atomically claim its tag before it +// can validly write) +// +// WARNING: This class has to be written very carefully, because it may be +// possible for a Tensor to have a reference an interpreter corresponding to +// a shared library that has ALREADY BEEN UNLOADED. This makes blindly calling +// virtual methods very dangerous, because the vtable may be garbage at that +// point (on a good day, you might get "pure virtual method called"). +// +// The idea to solve this problem is we always leak PyInterpreters (so they +// always stay live even after dlclose), and make sure we can disarm their +// virtual methods by indirecting through a separate PyInterpreterVTable +// object. This can be replaced with a no-op vtable from libc10.so, which +// is guaranteed to stick around until the bitter end. +// +// NB: The downside with representing PyInterpreter tags as full objects is that +// it takes an extra word on TensorImpl. If tags were instead just integer +// indices, on 64-bit architectures we could pack the tag and PyObject together +// into a single atomic word. On 32-bit architectures we could simply say that +// only one Python interpreter is supported (erroring if a nontrivial +// interpreter tag is attempted to be set). +// +// The difficulty with this scheme is we need to maintain an out-of-line table +// to get at the PyInterpreters so that we can do virtual method calls on them, +// and registration/deregistration to this table must be done in a thread safe +// manner. This can be easily done if the number of possible PyInterpreters is +// small enough (e.g., 8-bit integer) by simply preallocating an array of +// sufficient size to hold all possible interpreters. Surely 128 threads is +// more than enough for anyone! +// +// I didn't decide to do this technique at the moment, because the extra word +// added by the PyInterpreter tag takes us to 24 words, which means that we +// still fit inside three eight word cache lines. If you need to penny pinch +// another word consider doing this! + +struct C10_API PyInterpreterVTable { + virtual ~PyInterpreterVTable() = default; + + // Report the name of this interpreter + virtual std::string name() const = 0; + + // Run Py_INCREF on a PyObject. + virtual void incref(PyObject* pyobj) const = 0; + // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call. + virtual void decref(PyObject* pyobj) const = 0; + // Run PyUnstable_TryIncRef on a PyObject if it's not NULL. + virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0; + // Run Py_REFCNT on a PyObject. + virtual size_t refcnt(PyObject* pyobj) const = 0; + + // Perform a detach by deferring to the __torch_dispatch__ implementation of + // detach, which will also arrange for the PyObject to get copied in this + // situation + virtual c10::intrusive_ptr detach( + const TensorImpl* self) const = 0; + + // Invoke the Python boxed fallback dispatch to go back into Python + virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack) + const = 0; + + virtual void reportErrorCallback(PyObject* callback, DispatchKey key) + const = 0; + + // This is only invoked in the multipy/torchdeploy // codespell:ignore multipy + // situation from pythonOpRegistrationTrampoline; this lets us get to the + // Python interpreter to actually find the appropriate Python op registration + // entry to call. + virtual void python_op_registration_trampoline( + const c10::OperatorHandle& op, + c10::DispatchKey, + c10::DispatchKeySet keyset, + torch::jit::Stack* stack, + bool with_keyset, + bool with_op) const = 0; + + virtual void throw_abstract_impl_not_imported_error( + std::string opname, + const char* pymodule, + const char* context) const = 0; + + // Invoke the Python dispatcher to handle this call + virtual void python_dispatcher( + const c10::OperatorHandle& op, + c10::DispatchKeySet, + torch::jit::Stack* stack) const = 0; + + virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat) + const = 0; + virtual c10::SymBool sym_is_contiguous( + const TensorImpl* self, + at::MemoryFormat) const = 0; + virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat) + const = 0; + virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0; + virtual c10::Device device(const TensorImpl* self) const = 0; + virtual int64_t dim(const TensorImpl* self) const = 0; + virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0; + virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0; + virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0; + virtual c10::Layout layout(const TensorImpl* self) const = 0; + virtual int64_t numel(const TensorImpl* self) const = 0; + virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0; + virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0; + virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0; + + virtual void trace_gpu_event_creation( + c10::DeviceType device_type, + uintptr_t event) const = 0; + virtual void trace_gpu_event_deletion( + c10::DeviceType device_type, + uintptr_t event) const = 0; + virtual void trace_gpu_event_record( + c10::DeviceType device_type, + uintptr_t event, + uintptr_t stream) const = 0; + virtual void trace_gpu_event_wait( + c10::DeviceType device_type, + uintptr_t event, + uintptr_t stream) const = 0; + virtual void trace_gpu_memory_allocation( + c10::DeviceType device_type, + uintptr_t ptr) const = 0; + virtual void trace_gpu_memory_deallocation( + c10::DeviceType device_type, + uintptr_t ptr) const = 0; + virtual void trace_gpu_stream_creation( + c10::DeviceType device_type, + uintptr_t stream) const = 0; + virtual void trace_gpu_device_synchronization( + c10::DeviceType device_type) const = 0; + virtual void trace_gpu_stream_synchronization( + c10::DeviceType device_type, + uintptr_t stream) const = 0; + virtual void trace_gpu_event_synchronization( + c10::DeviceType device_type, + uintptr_t event) const = 0; + + virtual void reset_backward_hooks(const TensorImpl* self) const = 0; +}; + +struct C10_API PyInterpreter { + const PyInterpreterVTable* vtable_; + + PyInterpreter(const PyInterpreterVTable* vtable) : vtable_(vtable) {} + + const PyInterpreterVTable& operator*() const noexcept { + return *vtable_; + } + const PyInterpreterVTable* operator->() const noexcept { + return vtable_; + } + + // Disarm this PyInterpreter, making all of its methods noops. + // The vtable pointer is not an atomic at the moment, which means + // a disarm() invocation that is concurrent with active destructors + // is not thread safe and will trigger TSAN. My hope is that this + // situations doesn't ever actually happen; tensor destruction should + // quiesce when a dlclose happens, and any long lived tensors whose + // destructors would be disarmed here only begin the destruction process + // on process shutdown (long after the dlclose has occurred). + void disarm() noexcept; +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyInterpreterHooks.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyInterpreterHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..acd2003569302cffcce5a907bd7fd506ac984a7b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyInterpreterHooks.h @@ -0,0 +1,45 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace c10::impl { + +// Minimal interface for PyInterpreter hooks +struct C10_API PyInterpreterHooksInterface { + virtual ~PyInterpreterHooksInterface() = default; + + // Get the PyInterpreter instance + // Stub implementation throws error when Python is not available + virtual PyInterpreter* getPyInterpreter() const { + TORCH_CHECK( + false, + "PyTorch was compiled without Python support. " + "Cannot access Python interpreter from C++."); + } +}; + +struct C10_API PyInterpreterHooksArgs{}; + +C10_DECLARE_REGISTRY( + PyInterpreterHooksRegistry, + PyInterpreterHooksInterface, + PyInterpreterHooksArgs); + +#define REGISTER_PYTHON_HOOKS(clsname) \ + C10_REGISTER_CLASS(PyInterpreterHooksRegistry, clsname, clsname) + +// Get the global PyInterpreter hooks instance +C10_API const PyInterpreterHooksInterface& getPyInterpreterHooks(); + +// Helper function to get the global interpreter +C10_API PyInterpreter* getGlobalPyInterpreter(); + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyObjectSlot.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyObjectSlot.h new file mode 100644 index 0000000000000000000000000000000000000000..8ba0688f66e597d4398d4a7d0407b2683ceb30aa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PyObjectSlot.h @@ -0,0 +1,70 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch::utils { +class PyObjectPreservation; +} + +namespace c10::impl { + +struct C10_API PyObjectSlot { + public: + PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} + + // Query the PyObject interpreter. This may return null if there is no + // interpreter. + PyInterpreter* pyobj_interpreter() const { + return pyobj_interpreter_.load(std::memory_order_acquire); + } + + PyInterpreter& load_pyobj_interpreter() const { + auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); + TORCH_INTERNAL_ASSERT( + interpreter, "cannot access PyObject for Tensor - no interpreter set"); + return *interpreter; + } + + PyObject* load_pyobj() const { + return pyobj_.load(std::memory_order_acquire); + } + + void store_pyobj(PyObject* obj) { + pyobj_.store(obj, std::memory_order_release); + } + + bool has_unique_reference() const { + PyObject* pyobj = load_pyobj(); + return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1; + } + + void clear() { + pyobj_.store(nullptr, std::memory_order_relaxed); + pyobj_interpreter_.store(nullptr, std::memory_order_relaxed); + } + + private: + // This is now always the global interpreter if the PyObject is set. + // Maybe we can remove this field some day... + std::atomic pyobj_interpreter_; + + // The PyObject representing this Tensor or nullptr. Ownership is managed + // by intrusive_ptr. By the time the PyObjectSlot is destroyed, this + // reference is already dead. + std::atomic pyobj_; + + friend class torch::utils::PyObjectPreservation; +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PythonDispatcherTLS.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PythonDispatcherTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..cffb7fc31e3d18b4544027b261b98c686f81274a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/PythonDispatcherTLS.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::impl { + +struct C10_API PythonDispatcherTLS { + static void set_state(PyInterpreter* state); + static PyInterpreter* get_state(); + static void reset_state(); +}; + +struct C10_API DisablePythonDispatcher { + DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) { + PythonDispatcherTLS::set_state({}); + } + + DisablePythonDispatcher(DisablePythonDispatcher&& other) = delete; + DisablePythonDispatcher(const DisablePythonDispatcher&) = delete; + DisablePythonDispatcher& operator=(const DisablePythonDispatcher&) = delete; + DisablePythonDispatcher& operator=(DisablePythonDispatcher&&) = delete; + ~DisablePythonDispatcher() { + PythonDispatcherTLS::set_state(old_); + } + PyInterpreter* old_; +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/SizesAndStrides.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/SizesAndStrides.h new file mode 100644 index 0000000000000000000000000000000000000000..da3a9a0c4abacf6165ca946e62257771cf2790ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/SizesAndStrides.h @@ -0,0 +1,336 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include + +#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 + +namespace c10::impl { + +// Packed container for TensorImpl sizes and strides. +// This design improves on the previous approach of using a pair of +// c10::SmallVector by specializing for the operations we +// actually use and enforcing that the number of sizes is the same as +// the number of strides. The memory layout is as follows: +// +// 1 size_t for the size +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer +// to out-of-line array +class C10_API SizesAndStrides { + public: + // TODO: different iterator types for sizes & strides to prevent + // mixing the two accidentally. + using sizes_iterator = int64_t*; + using sizes_const_iterator = const int64_t*; + using strides_iterator = int64_t*; + using strides_const_iterator = const int64_t*; + + SizesAndStrides() { + size_at_unchecked(0) = 0; + stride_at_unchecked(0) = 1; + } + + ~SizesAndStrides() { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + } + + SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { + if (C10_LIKELY(rhs.isInline())) { + copyDataInline(rhs); + } else { + allocateOutOfLineStorage(size_); + copyDataOutline(rhs); + } + } + + bool operator==(const SizesAndStrides& other) const { + if (size_ != other.size_) { + return false; + } + return !( + isInline() + ? std::memcmp( + inlineStorage_, other.inlineStorage_, sizeof(inlineStorage_)) + : std::memcmp( + outOfLineStorage_, + other.outOfLineStorage_, + storageBytes(size_))); + } + + bool operator!=(const SizesAndStrides& other) const { + return !(*this == other); + } + + SizesAndStrides& operator=(const SizesAndStrides& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + if (isInline()) { + allocateOutOfLineStorage(rhs.size_); + } else { + resizeOutOfLineStorage(rhs.size_); + } + copyDataOutline(rhs); + } + size_ = rhs.size_; + return *this; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { + if (C10_LIKELY(isInline())) { + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } else { + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + + rhs.size_ = 0; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + // They're outline. We're going to steal their vector. + if (!isInline()) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + size_ = rhs.size_; + rhs.size_ = 0; + + return *this; + } + + size_t size() const noexcept { + return size_; + } + + const int64_t* sizes_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + int64_t* sizes_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + sizes_const_iterator sizes_begin() const noexcept { + return sizes_data(); + } + + sizes_iterator sizes_begin() noexcept { + return sizes_data(); + } + + sizes_const_iterator sizes_end() const noexcept { + return sizes_begin() + size(); + } + + sizes_iterator sizes_end() noexcept { + return sizes_begin() + size(); + } + + IntArrayRef sizes_arrayref() const noexcept { + return IntArrayRef{sizes_data(), size()}; + } + + void set_sizes(IntArrayRef newSizes) { + resize(newSizes.size()); + std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); + } + + void set_strides(IntArrayRef strides) { + TORCH_INTERNAL_ASSERT(strides.size() == size()); + std::copy(strides.begin(), strides.end(), strides_begin()); + } + + const int64_t* strides_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + int64_t* strides_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_begin() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_iterator strides_begin() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_end() const noexcept { + return strides_begin() + size(); + } + + strides_iterator strides_end() noexcept { + return strides_begin() + size(); + } + + IntArrayRef strides_arrayref() const noexcept { + return IntArrayRef{strides_data(), size()}; + } + + // Size accessors. + int64_t size_at(size_t idx) const noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t& size_at(size_t idx) noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t size_at_unchecked(size_t idx) const noexcept { + return sizes_data()[idx]; + } + + int64_t& size_at_unchecked(size_t idx) noexcept { + return sizes_data()[idx]; + } + + // Size accessors. + int64_t stride_at(size_t idx) const noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t& stride_at(size_t idx) noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t stride_at_unchecked(size_t idx) const noexcept { + return strides_data()[idx]; + } + + int64_t& stride_at_unchecked(size_t idx) noexcept { + return strides_data()[idx]; + } + + void resize(size_t newSize) { + const auto oldSize = size(); + if (newSize == oldSize) { + return; + } + if (C10_LIKELY( + newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { + if (oldSize < newSize) { + const auto bytesToZero = + (newSize - oldSize) * sizeof(inlineStorage_[0]); + memset(&inlineStorage_[oldSize], 0, bytesToZero); + memset( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], + 0, + bytesToZero); + } + size_ = newSize; + } else { + resizeSlowPath(newSize, oldSize); + } + } + + void resizeSlowPath(size_t newSize, size_t oldSize); + + private: + bool isInline() const noexcept { + return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + } + + void copyDataInline(const SizesAndStrides& rhs) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline()); + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } + + static size_t storageBytes(size_t size) noexcept { + return size * 2 * sizeof(int64_t); + } + + void allocateOutOfLineStorage(size_t size) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + TORCH_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void resizeOutOfLineStorage(size_t newSize) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); + outOfLineStorage_ = static_cast( + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + realloc(outOfLineStorage_, storageBytes(newSize))); + TORCH_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void copyDataOutline(const SizesAndStrides& rhs) noexcept { + memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_)); + } + + size_t size_{1}; + union { + int64_t* outOfLineStorage_; + // NOLINTNEXTLINE(*c-array*) + int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; + }; +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/TorchDispatchModeTLS.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/TorchDispatchModeTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..002bf4283806448b0cf9470116758b21fa5499e6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/TorchDispatchModeTLS.h @@ -0,0 +1,72 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::impl { + +enum class TorchDispatchModeKey : int8_t { + FAKE, + PROXY, + FUNCTIONAL, + NUM_MODE_KEYS +}; + +using PyObject_TorchDispatchMode = SafePyObjectT; + +struct C10_API TorchDispatchModeTLS { + // This API is NOT invariant safe. + // It must not take in an infra mode that uses TorchDispatchModeKey + // If you're pushing an infra mode onto the stack, we expect + // you to use set_mode + static void push_non_infra_mode_onto_stack( + std::shared_ptr mode); + // Pops the top mode of the stack, + // giving precedence to user modes before attempting to pop + // any infra modes + static const std::shared_ptr pop_stack(); + // Returns the highest-priority infra mode on the stack, + // along with its mode key. + static const std:: + tuple, TorchDispatchModeKey> + pop_highest_infra_mode(); + + static const std::shared_ptr& get_stack_at( + int64_t idx); + static int64_t stack_len(); + + static const std::optional> + get_mode(TorchDispatchModeKey mode_key); + static const std::optional> + unset_mode(TorchDispatchModeKey mode_key); + static void set_mode( + const std::shared_ptr& mode, + TorchDispatchModeKey mode_key); + + static const TorchDispatchModeTLS& get_state(); + static void set_state(TorchDispatchModeTLS state); + + static bool any_modes_set(bool skip_infra_modes = false); + + private: + std::vector> stack_; + // Users are allowed to push multiple ProxyTorchDispatchMode objects onto the + // stack + // However, we only allow a single FakeTensorMode onto the stack at a time + // (Pushing additional FakeTensorModes onto the stack is a no-op) + std::array< + std::optional>, + static_cast(TorchDispatchModeKey::NUM_MODE_KEYS)> + infra_modes_; +}; + +C10_API bool dispatch_mode_enabled(); + +C10_API std::string to_string(TorchDispatchModeKey mode_key); + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/VirtualGuardImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/VirtualGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..16b1970bfa1bbc7d6dc9c1a0463d17f3cb08b9fe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/VirtualGuardImpl.h @@ -0,0 +1,117 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10::impl { + +/** + * An implementation of DeviceGuardImplInterface which delegates + * to virtual dispatch on the DeviceGuardImpl registry. + */ +class VirtualGuardImpl final : public DeviceGuardImplInterface { + public: + VirtualGuardImpl(DeviceType device_type) + : impl_(getDeviceGuardImpl(device_type)) {} + // This constructor exists purely for testing + VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {} + + // Copying and moving is OK! + VirtualGuardImpl(const VirtualGuardImpl&) = default; + VirtualGuardImpl& operator=(const VirtualGuardImpl&) = default; + VirtualGuardImpl(VirtualGuardImpl&&) noexcept = default; + VirtualGuardImpl& operator=(VirtualGuardImpl&&) noexcept = default; + ~VirtualGuardImpl() override = default; + + DeviceType type() const override { + return impl_->type(); + } + Device exchangeDevice(Device d) const override { + return impl_->exchangeDevice(d); + } + Device getDevice() const override { + return impl_->getDevice(); + } + void setDevice(Device d) const override { + impl_->setDevice(d); + } + void uncheckedSetDevice(Device d) const noexcept override { + impl_->uncheckedSetDevice(d); + } + Stream getStream(Device d) const override { + return impl_->getStream(d); + } + Stream getNewStream(Device d, int priority = 0) const override { + return impl_->getNewStream(d, priority); + } + Stream getDefaultStream(Device d) const override { + return impl_->getDefaultStream(d); + } + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) + const override { + return impl_->getStreamFromGlobalPool(d, isHighPriority); + } + Stream exchangeStream(Stream s) const override { + return impl_->exchangeStream(s); + } + DeviceIndex deviceCount() const noexcept override { + return impl_->deviceCount(); + } + + DeviceCapability getDeviceCapability(Device d) const override { + return impl_->getDeviceCapability(d); + } + + // Event functions + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + impl_->record(event, stream, device_index, flag); + } + void block(void* event, const Stream& stream) const override { + impl_->block(event, stream); + } + bool queryEvent(void* event) const override { + return impl_->queryEvent(event); + } + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { + impl_->destroyEvent(event, device_index); + } + + bool queryStream(const Stream& stream) const override { + return impl_->queryStream(stream); + } + void synchronizeStream(const Stream& stream) const override { + impl_->synchronizeStream(stream); + } + + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { + impl_->recordDataPtrOnStream(data_ptr, stream); + } + + double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) + const override { + return impl_->elapsedTime(event1, event2, device_index); + } + + void synchronizeEvent(void* event) const override { + impl_->synchronizeEvent(event); + } + + void synchronizeDevice(const DeviceIndex device_index) const override { + impl_->synchronizeDevice(device_index); + } + + private: + const DeviceGuardImplInterface* impl_ = nullptr; +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/alloc_cpu.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/alloc_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..ef28ed469f010d3aedeb5d68ad5405c2ffdaa055 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/impl/alloc_cpu.h @@ -0,0 +1,32 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include + +namespace c10 { + +C10_API void* alloc_cpu(size_t nbytes); +C10_API void free_cpu(void* data); + +#if defined(__linux__) && !defined(__ANDROID__) +C10_API size_t c10_compute_alignment(size_t nbytes); +#endif + +#ifdef USE_MIMALLOC_ON_MKL +namespace mi_malloc_wrapper { +C10_API void* c10_mi_malloc(size_t size); +C10_API void* c10_mi_calloc(size_t count, size_t size); +C10_API void* c10_mi_realloc(void* p, size_t newsize); +C10_API void* c10_mi_malloc_aligned(size_t size, size_t alignment); +C10_API void c10_mi_free(void* p); +} // namespace mi_malloc_wrapper +#endif + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/thread_pool.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/thread_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..85b9a73d6bfa7bdf5a815c6e659f0c4af6bd8ef8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/core/thread_pool.h @@ -0,0 +1,125 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +class C10_API TaskThreadPoolBase { + public: + virtual void run(std::function func) = 0; + + virtual size_t size() const = 0; + + /** + * The number of available (i.e. idle) threads in this thread pool. + */ + virtual size_t numAvailable() const = 0; + + /** + * Check if the current thread is from the thread pool. + */ + virtual bool inThreadPool() const = 0; + + virtual ~TaskThreadPoolBase() noexcept = default; + + static size_t defaultNumThreads(); +}; + +class C10_API ThreadPool : public c10::TaskThreadPoolBase { + protected: + struct task_element_t { + bool run_with_id; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::function no_id; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::function with_id; + + explicit task_element_t(std::function f) + : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {} + explicit task_element_t(std::function f) + : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {} + }; + + std::queue tasks_; + std::vector threads_; + mutable std::mutex mutex_; + std::condition_variable condition_; + std::condition_variable completed_; + std::atomic_bool running_; + bool complete_; + std::size_t available_; + std::size_t total_; + int numa_node_id_; + + public: + ThreadPool() = delete; + + explicit ThreadPool( + int pool_size, + int numa_node_id = -1, + const std::function& init_thread = nullptr); + + ~ThreadPool() override; + + size_t size() const override; + + size_t numAvailable() const override; + + bool inThreadPool() const override; + + void run(std::function func) override; + + template + void runTaskWithID(Task task) { + std::unique_lock lock(mutex_); + + // Set task and signal condition variable so that a worker thread will + // wake up and use the task. + tasks_.emplace(static_cast>(task)); + complete_ = false; + condition_.notify_one(); + } + + /// @brief Wait for queue to be empty + void waitWorkComplete(); + + private: + // @brief Entry point for pool threads. + void main_loop(std::size_t index); +}; + +class C10_API TaskThreadPool : public c10::ThreadPool { + public: + explicit TaskThreadPool(int pool_size, int numa_node_id = -1) + : ThreadPool(pool_size, numa_node_id, [numa_node_id]() { + setThreadName("CaffeTaskThread"); + NUMABind(numa_node_id); + }) {} +}; + +C10_DECLARE_SHARED_REGISTRY( + ThreadPoolRegistry, + TaskThreadPoolBase, + int, + int, + bool); + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAAlgorithm.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAAlgorithm.h new file mode 100644 index 0000000000000000000000000000000000000000..62995e142a3e84bf83e2e7143cdc6bc8eb67f91f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAAlgorithm.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS +#include +#include +#include +#include +#endif +namespace c10::cuda { +#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS +template +__forceinline__ __device__ Iter +lower_bound(Iter start, Iter end, Scalar value) { + return thrust::lower_bound(thrust::device, start, end, value); +} +#else +// thrust::lower_bound is broken on device, see +// https://github.com/NVIDIA/thrust/issues/1734 Implementation inspired by +// https://github.com/pytorch/pytorch/blob/805120ab572efef66425c9f595d9c6c464383336/aten/src/ATen/native/cuda/Bucketization.cu#L28 +template +__device__ Iter lower_bound(Iter start, Iter end, Scalar value) { + while (start < end) { + auto mid = start + ((end - start) >> 1); + if (*mid < value) { + start = mid + 1; + } else { + end = mid; + } + } + return end; +} +#endif // THRUST_DEVICE_LOWER_BOUND_WORKS +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAAllocatorConfig.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAAllocatorConfig.h new file mode 100644 index 0000000000000000000000000000000000000000..286eb3daecb5aa73711392c839776ab5e0444275 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAAllocatorConfig.h @@ -0,0 +1,211 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10::cuda::CUDACachingAllocator { + +enum class Expandable_Segments_Handle_Type : int { + UNSPECIFIED = 0, + POSIX_FD = 1, + FABRIC_HANDLE = 2, +}; + +// Environment config parser +class C10_CUDA_API CUDAAllocatorConfig { + public: + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.") + static size_t max_split_size() { + return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); + } + + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.") + static double garbage_collection_threshold() { + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + garbage_collection_threshold(); + } + + static bool expandable_segments() { + bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig:: + use_expandable_segments(); +#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED + if (enabled) { + TORCH_WARN_ONCE("expandable_segments not supported on this platform") + } + return false; +#else + return enabled; +#endif + } + + static Expandable_Segments_Handle_Type expandable_segments_handle_type() { + return instance().m_expandable_segments_handle_type; + } + + static void set_expandable_segments_handle_type( + Expandable_Segments_Handle_Type handle_type) { + instance().m_expandable_segments_handle_type = handle_type; + } + + static bool release_lock_on_cudamalloc() { + return instance().m_release_lock_on_cudamalloc; + } + + static bool graph_capture_record_stream_reuse() { + return instance().m_graph_capture_record_stream_reuse; + } + + static double per_process_memory_fraction() { + return instance().m_per_process_memory_fraction; + } + + /** Pinned memory allocator settings */ + static bool pinned_use_cuda_host_register() { + return instance().m_pinned_use_cuda_host_register; + } + + static size_t pinned_num_register_threads() { + return instance().m_pinned_num_register_threads; + } + + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.") + static bool pinned_use_background_threads() { + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + pinned_use_background_threads(); + } + + static size_t pinned_reserve_segment_size_mb() { + return instance().m_pinned_reserve_segment_size_mb; + } + + static size_t pinned_max_register_threads() { + // Based on the benchmark results, we see better allocation performance + // with 8 threads. However on future systems, we may need more threads + // and limiting this to 128 threads. + return 128; + } + + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") + static size_t roundup_power2_divisions(size_t size) { + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + roundup_power2_divisions(size); + } + + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") + static std::vector roundup_power2_divisions() { + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + roundup_power2_divisions(); + } + + static size_t max_non_split_rounding_size() { + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + max_non_split_rounding_size(); + } + + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.") + static std::string last_allocator_settings() { + return c10::CachingAllocator::getAllocatorSettings(); + } + + static CUDAAllocatorConfig& instance() { + static CUDAAllocatorConfig* s_instance = ([]() { + auto inst = new CUDAAllocatorConfig(); + auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); +#ifdef USE_ROCM + // convenience for ROCm users, allow alternative HIP token + if (!env.has_value()) { + env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); + } +#endif + // Note: keep the parsing order and logic stable to avoid potential + // performance regressions in internal tests. + if (!env.has_value()) { + env = c10::utils::get_env("PYTORCH_ALLOC_CONF"); + } + if (env.has_value()) { + inst->parseArgs(env.value()); + } + return inst; + })(); + return *s_instance; + } + + // Use `Construct On First Use Idiom` to avoid `Static Initialization Order` + // issue. + static const std::unordered_set& getKeys() { + static std::unordered_set keys{ + "backend", + // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues + // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors) + "release_lock_on_cud" + "amalloc", + "pinned_use_cud" + "a_host_register", + // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors) + "release_lock_on_hipmalloc", + "pinned_use_hip_host_register", + "graph_capture_record_stream_reuse", + "pinned_reserve_segment_size_mb", + "pinned_num_register_threads", + "per_process_memory_fraction"}; + return keys; + } + + void parseArgs(const std::string& env); + + private: + CUDAAllocatorConfig() = default; + + size_t parseAllocatorConfig( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i, + bool& used_cudaMallocAsync); + size_t parsePinnedUseCudaHostRegister( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); + size_t parsePinnedNumRegisterThreads( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); + size_t parsePinnedReserveSegmentSize( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); + size_t parseGraphCaptureRecordStreamReuse( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); + double parsePerProcessMemoryFraction( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); + + std::atomic m_pinned_num_register_threads{1}; + std::atomic m_pinned_reserve_segment_size_mb{0}; + std::atomic m_expandable_segments_handle_type +#if CUDA_VERSION >= 12030 + {Expandable_Segments_Handle_Type::UNSPECIFIED}; +#else + {Expandable_Segments_Handle_Type::POSIX_FD}; +#endif + std::atomic m_release_lock_on_cudamalloc{false}; + std::atomic m_pinned_use_cuda_host_register{false}; + std::atomic m_graph_capture_record_stream_reuse{false}; + std::atomic m_per_process_memory_fraction{1.0}; +}; + +// Keep this for backwards compatibility +using c10::CachingAllocator::setAllocatorSettings; + +} // namespace c10::cuda::CUDACachingAllocator + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDACachingAllocator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDACachingAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..b425157814aa15296d38633501e47035e2804130 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDACachingAllocator.h @@ -0,0 +1,582 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// Caching allocator will execute every registered callback if it unable to find +// block inside of already allocated area. +class C10_CUDA_API FreeMemoryCallback { + public: + virtual ~FreeMemoryCallback() = default; + virtual bool Execute() = 0; +}; + +C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); +#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \ + C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__) +} // namespace c10 + // +// TODO: Turn this into an honest to goodness class. I briefly attempted to do +// this, but it was a bit irritating to figure out how to also correctly +// apply pimpl pattern so I didn't have to leak any internal implementation +// details in the header (CUDACachingAllocator could be made a pimpl, but +// you also need to appropriately define a class which is a subclass +// of Allocator. Not impossible, but required a bit more surgery than +// I wanted to do at the time.) +// +// Why is this using a namespace rather than old-style THCCachingAllocator_ +// prefix? Mostly because it made the HIPify rules easier to write; _ is +// not counted as a word boundary, so you would otherwise have to list each +// of these functions. + +namespace c10::cuda::CUDACachingAllocator { + +// Preserved only for BC reasons +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::CachingAllocator::kLargeBuffer; +using c10::CachingDeviceAllocator::DeviceStats; + +typedef std::shared_ptr (*CreateContextFn)(); + +// Struct containing info of an allocation block (i.e. a fractional part of a +// cudaMalloc).. +struct BlockInfo { + size_t size = 0; + size_t requested_size = 0; + int32_t gc_counter = 0; + bool allocated = false; + bool active = false; + std::shared_ptr + context_when_allocated; // per-watcher context +}; + +// Struct containing info of a memory segment (i.e. one contiguous cudaMalloc). +struct SegmentInfo { + c10::DeviceIndex device = 0; + size_t address = 0; + size_t total_size = 0; + size_t requested_size = 0; // unrounded, actually requested size + size_t allocated_size = 0; + size_t active_size = 0; + cudaStream_t stream = nullptr; + bool is_large = false; + bool is_expandable = false; + MempoolId_t owner_private_pool_id = {0, 0}; + std::vector blocks; + std::shared_ptr context_when_allocated; +}; + +struct AllocatorState { + virtual ~AllocatorState() = default; +}; + +union trace_time_ { + time_t t_; + approx_time_t approx_t_; +}; + +struct TraceEntry { + enum Action { + ALLOC, // API made to the caching allocator for new memory + FREE_REQUESTED, // API call made to the caching allocator to free memory + FREE_COMPLETED, // The allocator might have to delay a free because + // it is still in use on another stream via record_stream + // This event is generated when a free actually completes. + SEGMENT_ALLOC, // a call to cudaMalloc to get more memory from the OS + SEGMENT_FREE, // a call to cudaFree to return memory to the OS (e.g. to + // defragment or empty_caches) + SEGMENT_MAP, // a call to cuMemMap (used with expandable_segments) + SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments) + SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace + // events + OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free + // bytes reported by cuda) + }; + TraceEntry( + Action action, + c10::DeviceIndex device, + size_t addr, + size_t size, + cudaStream_t stream, + MempoolId_t mempool, + approx_time_t time, + std::shared_ptr context = nullptr, + std::string compile_context = "", + std::string user_metadata = "") + : action_(action), + device_(device), + addr_(addr), + context_(std::move(context)), + stream_(stream), + size_(size), + mempool_(std::move(mempool)), + compile_context_(std::move(compile_context)), + user_metadata_(std::move(user_metadata)) { + time_.approx_t_ = time; + } + Action action_; + c10::DeviceIndex device_; + size_t addr_; // for OOM, this is the amount of free bytes reported by cuda + std::shared_ptr context_; + cudaStream_t stream_{}; + size_t size_; + MempoolId_t mempool_; + trace_time_ time_{}; + std::string compile_context_; + std::string user_metadata_; +}; + +// Calls made by record_function will save annotations +struct AnnotationEntry { + AnnotationEntry(c10::DeviceIndex device, approx_time_t time) + : device_(device) { + time_.approx_t_ = time; + } + + void recordUserMetadata(const std::string& name, std::string value) { + metadata_[name] = std::move(value); + } + + c10::DeviceIndex device_; + trace_time_ time_{}; + std::unordered_map metadata_; +}; + +struct AllocatorConfigInfo { + double garbage_collection_threshold; + size_t max_split_size; + size_t pinned_num_register_threads; + bool expandable_segments; + bool release_lock_on_malloc; + bool pinned_use_host_register; + bool graph_capture_record_stream_reuse; + std::string last_allocator_settings; + std::vector roundup_power2_divisions; +}; + +struct SnapshotInfo { + std::vector segments; + std::vector> device_traces; + std::vector external_annotations; + AllocatorConfigInfo config_metadata; +}; + +// returns the pointers freed in the pool +// and the pointers allocated. Note: a pointer +// may appear in both freed and allocated +struct CheckpointDelta { + std::vector ptrs_freed; + std::vector dataptrs_allocd; +}; + +enum struct RecordContext { + NEVER = 0, + STATE = 1, // only keep stacks for active allocations + ALLOC = 2, // additionally keep stacks for allocations in the trace history + ALL = 3, // additionally record stacks for when something is freed +}; + +using OutOfMemoryObserver = std::function; + +using AllocatorTraceTracker = std::function; + +struct ShareableHandle { + ptrdiff_t offset; + std::string handle; +}; + +struct StreamSegmentSize { + StreamSegmentSize(cudaStream_t s, bool small, size_t sz) + : stream(s), is_small_pool(small), total_size(sz) {} + cudaStream_t stream; + bool is_small_pool; + size_t total_size; +}; + +class CUDAAllocator : public DeviceAllocator { + public: + virtual void* raw_alloc(size_t nbytes) = 0; + virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; + virtual void raw_delete(void* ptr) = 0; + virtual void init(int device_count) = 0; + virtual double getMemoryFraction(c10::DeviceIndex device) = 0; + virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; + virtual std::vector getExpandableSegmentSizes( + c10::DeviceIndex device) = 0; + virtual void enable(bool value) = 0; + virtual bool isEnabled() const = 0; + virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; + virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; + // Keep for BC only + virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; + void recordStream(const DataPtr& ptr, c10::Stream stream) override { + CUDAStream cuda_stream = CUDAStream(stream); + recordStream(ptr, cuda_stream); + } + virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; + virtual void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) = 0; + virtual void endAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id) = 0; + virtual void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) = 0; + virtual int getPoolUseCount( + c10::DeviceIndex /*device*/, + MempoolId_t /*mempool_id*/) { + TORCH_CHECK( + false, + name(), + " does not yet support getPoolUseCount. " + "If you need it, please file an issue describing your use case."); + } + virtual void createOrIncrefPool( + c10::DeviceIndex /*device*/, + MempoolId_t /*mempool_id*/, + CUDAAllocator* allocator = nullptr) { + TORCH_CHECK( + false, + name(), + " does not yet support createOrIncrefPool. " + "If you need it, please file an issue describing your use case."); + } + virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { + TORCH_CHECK( + false, + name(), + " does not yet support setUseOnOOM. " + "If you need it, please file an issue describing your use case."); + } + virtual void setNoSplit(c10::DeviceIndex device, MempoolId_t mempool_id) { + TORCH_CHECK( + false, + name(), + " does not yet support setNoSplit. " + "If you need it, please file an issue describing your use case."); + } + + // returns true if the allocated blocks are equal to expected live allocations + virtual bool checkPoolLiveAllocations( + c10::DeviceIndex /*device*/, + MempoolId_t /*mempool_id*/, + const std::unordered_set& /*expected_live_allocations*/) { + TORCH_CHECK( + false, + name(), + " does not yet support checkPoolLiveAllocations. " + "If you need it, please file an issue describing your use case."); + } + virtual ShareableHandle shareIpcHandle(void* ptr) = 0; + virtual std::shared_ptr getIpcDevPtr(std::string handle) = 0; + virtual bool isHistoryEnabled() { + TORCH_CHECK( + false, + name(), + " does not yet support recordHistory. " + "If you need it, please file an issue describing your use case."); + } + virtual void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when, + bool clearHistory) = 0; + virtual void recordAnnotation( + const std::vector>& /*md*/) {} + virtual void pushCompileContext(std::string& md) {} + virtual void popCompileContext() {} + virtual void setUserMetadata(const std::string& metadata) {} + virtual std::string getUserMetadata() { + return ""; + } + virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0; + + // Attached AllocatorTraceTracker callbacks will be called while the + // per-device allocator lock is held. Any additional locks taken from within + // the callback must be proven to always have the lock order that never + // triggers a deadlock. In particular, Python's GIL may be held when + // calling the allocator so it is unsafe to try to acquire the GIL in this + // callback. + virtual void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) = 0; + + virtual void enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access) = 0; + + // memory not allocated from cudaMalloc cannot be copied + // across devices using cudaMemcpyAsync if peer to peer access is disabled. + // instead it requires cudaMemcpyAsyncPeer + // with P2P Enabled, all combinations work + // with P2P Disabled: + // cudaMalloc cudaMallocAsync/cuMemMap + // cudaMemcpyAsyncPeer works works + // cudaMemcpyAsync works error + + // This function performs chooses to use the Peer version of + // memcpy if required based on where the allocated put dst/src. + virtual cudaError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + cudaStream_t stream, + bool p2p_enabled) = 0; + virtual std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) = 0; + virtual CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) = 0; + virtual std::string name() = 0; + std::pair getMemoryInfo(c10::DeviceIndex device) override { + c10::DeviceGuard device_guard({at::kCUDA, device}); + size_t free = 0; + size_t total = 0; + C10_CUDA_CHECK(cudaMemGetInfo(&free, &total)); + return {free, total}; + } +}; + +// Allocator object, statically initialized +// See BackendInitializer in CUDACachingAllocator.cpp. +// Atomic loads on x86 are just normal loads, +// (atomic stores are different), so reading this value +// is no different than loading a pointer. +C10_CUDA_API extern std::atomic allocator; + +inline CUDAAllocator* get() { + return allocator.load(); +} + +// Called directly by clients. +inline void* raw_alloc(size_t nbytes) { + return get()->raw_alloc(nbytes); +} + +inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) { + return get()->raw_alloc_with_stream(nbytes, stream); +} + +inline void raw_delete(void* ptr) { + get()->raw_delete(ptr); +} + +inline void init(int device_count) { + get()->init(device_count); +} + +inline double getMemoryFraction(c10::DeviceIndex device) { + return get()->getMemoryFraction(device); +} + +inline void setMemoryFraction(double fraction, c10::DeviceIndex device) { + get()->setMemoryFraction(fraction, device); +} + +inline std::vector getExpandableSegmentSizes( + c10::DeviceIndex device) { + return get()->getExpandableSegmentSizes(device); +} + +inline void emptyCache(MempoolId_t mempool_id = {0, 0}) { + get()->emptyCache(mempool_id); +} + +inline void enable(bool value) { + get()->enable(value); +} + +inline bool isEnabled() { + return get()->isEnabled(); +} + +inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) { + get()->cacheInfo(device, largestBlock); +} + +inline void* getBaseAllocation(void* ptr, size_t* size) { + return get()->getBaseAllocation(ptr, size); +} + +inline void recordStream(const DataPtr& dataPtr, CUDAStream stream) { + get()->recordStream(dataPtr, stream); +} + +inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) { + return get()->getDeviceStats(device); +} + +inline void resetAccumulatedStats(c10::DeviceIndex device) { + get()->resetAccumulatedStats(device); +} + +inline void resetPeakStats(c10::DeviceIndex device) { + get()->resetPeakStats(device); +} + +inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) { + return get()->snapshot(mempool_id); +} + +inline std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) { + return get()->getCheckpointState(device, id); +} + +inline CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) { + return get()->setCheckpointPoolState(device, std::move(pps)); +} + +// CUDAGraph interactions +inline void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + get()->beginAllocateToPool(device, mempool_id, std::move(filter)); +} + +inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + get()->endAllocateToPool(device, mempool_id); +} + +inline void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when, + bool clearHistory) { + get()->recordHistory( + enabled, context_recorder, alloc_trace_max_entries, when, clearHistory); +} + +inline void recordAnnotation( + const std::vector>& md) { + get()->recordAnnotation(md); +} + +inline void pushCompileContext(std::string& md) { + get()->pushCompileContext(md); +} + +inline void popCompileContext() { + get()->popCompileContext(); +} + +inline bool isHistoryEnabled() { + return get()->isHistoryEnabled(); +} + +inline bool checkPoolLiveAllocations( + c10::DeviceIndex device, + MempoolId_t mempool_id, + const std::unordered_set& expected_live_allocations) { + return get()->checkPoolLiveAllocations( + device, mempool_id, expected_live_allocations); +} + +inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) { + get()->attachOutOfMemoryObserver(std::move(observer)); +} + +inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { + get()->attachAllocatorTraceTracker(std::move(tracker)); +} + +inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + get()->releasePool(device, mempool_id); +} +inline void createOrIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + CUDAAllocator* allocator_ptr = nullptr) { + get()->createOrIncrefPool(device, mempool_id, allocator_ptr); +} +inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { + get()->setUseOnOOM(device, mempool_id); +} +inline void setNoSplit(c10::DeviceIndex device, MempoolId_t mempool_id) { + get()->setNoSplit(device, mempool_id); +} +inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + return get()->getPoolUseCount(device, mempool_id); +} + +// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE +inline std::shared_ptr getIpcDevPtr(std::string handle) { + return get()->getIpcDevPtr(std::move(handle)); +} + +inline ShareableHandle shareIpcHandle(void* ptr) { + return get()->shareIpcHandle(ptr); +} + +inline std::string name() { + return get()->name(); +} + +inline cudaError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + cudaStream_t stream, + bool p2p_enabled) { + return get()->memcpyAsync( + dst, dstDevice, src, srcDevice, count, stream, p2p_enabled); +} + +inline void enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access) { + get()->enablePeerAccess(dev, dev_to_access); +} + +inline void setUserMetadata(const std::string& metadata) { + get()->setUserMetadata(metadata); +} + +inline std::string getUserMetadata() { + return get()->getUserMetadata(); +} + +} // namespace c10::cuda::CUDACachingAllocator + +namespace c10::cuda { +// Keep BC only +using c10::CaptureId_t; +using c10::MempoolId_t; +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDADeviceAssertion.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDADeviceAssertion.h new file mode 100644 index 0000000000000000000000000000000000000000..294734601cb78d68aff50da939b3452c948adb80 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDADeviceAssertion.h @@ -0,0 +1,103 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::cuda { + +#ifdef TORCH_USE_CUDA_DSA +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") +// Copy string from `src` to `dst` +static __device__ void dstrcpy(char* dst, const char* src) { + int i = 0; + // Copy string from source to destination, ensuring that it + // isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1` + while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) { + *dst++ = *src++; + } + *dst = '\0'; +} + +static __device__ void dsa_add_new_assertion_failure( + DeviceAssertionsData* assertions_data, + const char* assertion_msg, + const char* filename, + const char* function_name, + const int line_number, + const uint32_t caller, + const dim3 block_id, + const dim3 thread_id) { + // `assertions_data` may be nullptr if device-side assertion checking + // is disabled at run-time. If it is disabled at compile time this + // function will never be called + if (!assertions_data) { + return; + } + + // Atomically increment so other threads can fail at the same time + // Note that incrementing this means that the CPU can observe that + // a failure has happened and can begin to respond before we've + // written information about that failure out to the buffer. + const auto nid = atomicAdd(&(assertions_data->assertion_count), 1); + + if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) { + // At this point we're ran out of assertion buffer space. + // We could print a message about this, but that'd get + // spammy if a lot of threads did it, so we just silently + // ignore any other assertion failures. In most cases the + // failures will all probably be analogous anyway. + return; + } + + // Write information about the assertion failure to memory. + // Note that this occurs only after the `assertion_count` + // increment broadcasts that there's been a problem. + auto& self = assertions_data->assertions[nid]; + dstrcpy(self.assertion_msg, assertion_msg); + dstrcpy(self.filename, filename); + dstrcpy(self.function_name, function_name); + self.line_number = line_number; + self.caller = caller; + self.block_id[0] = block_id.x; + self.block_id[1] = block_id.y; + self.block_id[2] = block_id.z; + self.thread_id[0] = thread_id.x; + self.thread_id[1] = thread_id.y; + self.thread_id[2] = thread_id.z; +} +C10_CLANG_DIAGNOSTIC_POP() + +// Emulates a kernel assertion. The assertion won't stop the kernel's progress, +// so you should assume everything the kernel produces is garbage if there's an +// assertion failure. +// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are +// arguments of the kernel and therefore accessible. +#define CUDA_KERNEL_ASSERT2(condition) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + /* Has an atomic element so threads can fail at the same time */ \ + c10::cuda::dsa_add_new_assertion_failure( \ + assertions_data, \ + C10_STRINGIZE(condition), \ + __FILE__, \ + __FUNCTION__, \ + __LINE__, \ + assertion_caller_id, \ + blockIdx, \ + threadIdx); \ + /* Now that the kernel has failed we early exit the kernel, but */ \ + /* otherwise keep going and rely on the host to check UVM and */ \ + /* determine we've had a problem */ \ + return; \ + } \ + } while (false) +#else +#define CUDA_KERNEL_ASSERT2(condition) assert(condition) +#endif + +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDADeviceAssertionHost.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDADeviceAssertionHost.h new file mode 100644 index 0000000000000000000000000000000000000000..2d4921a100a1c73e2fd5a69284cd92435b7f70f4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDADeviceAssertionHost.h @@ -0,0 +1,169 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#ifdef USE_CUDA +#define TORCH_USE_CUDA_DSA +#endif + +/// Number of assertion failure messages we can store. If this is too small +/// threads will fail silently. +constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10; +constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512; + +namespace c10::cuda { + +/// Holds information about any device-side assertions that fail. +/// Held in managed memory and access by both the CPU and the GPU. +struct DeviceAssertionData { + /// Stringification of the assertion + // NOLINTNEXTLINE(*-c-arrays) + char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN]{}; + /// File the assertion was in + // NOLINTNEXTLINE(*-c-arrays) + char filename[C10_CUDA_DSA_MAX_STR_LEN]{}; + /// Name of the function the assertion was in + // NOLINTNEXTLINE(*-c-arrays) + char function_name[C10_CUDA_DSA_MAX_STR_LEN]{}; + /// Line number the assertion was at + int line_number{}; + /// Number uniquely identifying the kernel launch that triggered the assertion + uint32_t caller{}; + /// block_id of the thread that failed the assertion + // NOLINTNEXTLINE(*-c-arrays) + int32_t block_id[3]{}; + /// third_id of the thread that failed the assertion + // NOLINTNEXTLINE(*-c-arrays) + int32_t thread_id[3]{}; +}; + +/// Used to hold assertions generated by the device +/// Held in managed memory and access by both the CPU and the GPU. +struct DeviceAssertionsData { + /// Total number of assertions found; a subset of these will be recorded + /// in `assertions` + int32_t assertion_count{}; + /// An array of assertions that will be written to in a race-free manner + // NOLINTNEXTLINE(*-c-arrays) + DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT]{}; +}; + +/// Use to hold info about kernel launches so that we can run kernels +/// asynchronously and still associate launches with device-side +/// assertion failures +struct CUDAKernelLaunchInfo { + /// Filename of the code where the kernel was launched from + const char* launch_filename; + /// Function from which the kernel was launched + const char* launch_function; + /// Line number of where the code was launched from + uint32_t launch_linenum; + /// Backtrace of where the kernel was launched from, only populated if + /// CUDAKernelLaunchRegistry::gather_launch_stacktrace is True + std::string launch_stacktrace; + /// Kernel that was launched + const char* kernel_name; + /// Device the kernel was launched on + int device; + /// Stream the kernel was launched on + int32_t stream; + /// A number that uniquely identifies the kernel launch + uint64_t generation_number; +}; + +/// Circular buffer used to hold information about kernel launches +/// this is later used to reconstruct how a device-side kernel assertion failure +/// occurred CUDAKernelLaunchRegistry is used as a singleton +class C10_CUDA_API CUDAKernelLaunchRegistry { + private: + /// Assume that this is the max number of kernel launches that might ever be + /// enqueued across all streams on a single device + static constexpr int max_kernel_launches = 1024; + /// How many kernel launch infos we've inserted. Used to ensure that circular + /// queue doesn't provide false information by always increasing, but also to + /// mark where we are inserting into the queue +#ifdef TORCH_USE_CUDA_DSA + uint64_t generation_number = 0; +#endif + /// Shared mutex between writer and accessor to ensure multi-threaded safety. + mutable std::mutex read_write_mutex; + /// Used to ensure prevent race conditions in GPU memory allocation + mutable std::mutex gpu_alloc_mutex; + /// Pointer to managed memory keeping track of device-side assertions. There + /// is one entry for each possible device the process might work with. Unused + /// entries are nullptrs. We could also use an unordered_set here, but this + /// vector design will be faster and the wasted memory is small since we + /// expect the number of GPUs per node will always be small + std::vector< + std::unique_ptr> + uvm_assertions; + /// A single circular buffer holds information about every kernel launch the + /// process makes across all devices. + std::vector kernel_launches; + bool check_env_for_enable_launch_stacktracing() const; + bool check_env_for_dsa_enabled() const; + + public: + CUDAKernelLaunchRegistry(); + /// Register a new kernel launch and obtain a generation number back to be + /// passed to the kernel + uint32_t insert( + const char* launch_filename, + const char* launch_function, + const uint32_t launch_linenum, + const char* kernel_name, + const int32_t stream_id); + /// Get copies of the kernel launch registry and each device's assertion + /// failure buffer so they can be inspected without raising race conditions + std:: + pair, std::vector> + snapshot() const; + /// Get a pointer to the current device's assertion failure buffer. If no such + /// buffer exists then one is created. This means that the first kernel launch + /// made on each device will be slightly slower because memory allocations are + /// required + DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device(); + /// Gets the global singleton of the registry + static CUDAKernelLaunchRegistry& get_singleton_ref(); + /// If not all devices support DSA, we disable it + const bool do_all_devices_support_managed_memory = false; + /// Whether or not to gather stack traces when launching kernels + bool gather_launch_stacktrace = false; + /// Whether or not host-side DSA is enabled or disabled at run-time + /// Note: Device-side code cannot be enabled/disabled at run-time + bool enabled_at_runtime = false; + /// Whether or not a device has indicated a failure + bool has_failed() const; +#ifdef TORCH_USE_CUDA_DSA + const bool enabled_at_compile_time = true; +#else + const bool enabled_at_compile_time = false; +#endif +}; + +C10_CUDA_API std::string c10_retrieve_device_side_assertion_info(); + +} // namespace c10::cuda + +// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH +// requires the same input arguments. We introduce the following macro to +// standardize these. +#define TORCH_DSA_KERNEL_ARGS \ + [[maybe_unused]] c10::cuda::DeviceAssertionsData *const assertions_data, \ + [[maybe_unused]] uint32_t assertion_caller_id + +// This macro can be used to pass the DSA arguments onward to another +// function +#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAException.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAException.h new file mode 100644 index 0000000000000000000000000000000000000000..71a5a9b86d8833ca28adad37f36061b201b2d5d5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAException.h @@ -0,0 +1,102 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// Note [CHECK macro] +// ~~~~~~~~~~~~~~~~~~ +// This is a macro so that AT_ERROR can get accurate __LINE__ +// and __FILE__ information. We could split this into a short +// macro and a function implementation if we pass along __LINE__ +// and __FILE__, but no one has found this worth doing. + +// Used to denote errors from CUDA framework. +// This needs to be declared here instead util/Exception.h for proper conversion +// during hipify. +namespace c10 { +class C10_CUDA_API CUDAError : public c10::Error { + using Error::Error; +}; +} // namespace c10 + +#define C10_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + c10::cuda::c10_cuda_check_implementation( \ + static_cast(__err), \ + __FILE__, \ + __func__, /* Line number data type not well-defined between \ + compilers, so we perform an explicit cast */ \ + static_cast(__LINE__), \ + true); \ + } while (0) + +#define C10_CUDA_CHECK_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (C10_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ + TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err)); \ + } \ + } while (0) + +// Indicates that a CUDA error is handled in a non-standard way +#define C10_CUDA_ERROR_HANDLED(EXPR) EXPR + +// Intentionally ignore a CUDA error +#define C10_CUDA_IGNORE_ERROR(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (C10_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] cudaError_t error_unused = cudaGetLastError(); \ + } \ + } while (0) + +// Clear the last CUDA error +#define C10_CUDA_CLEAR_ERROR() \ + do { \ + [[maybe_unused]] cudaError_t error_unused = cudaGetLastError(); \ + } while (0) + +// This should be used directly after every kernel launch to ensure +// the launch happened correctly and provide an early, close-to-source +// diagnostic if it didn't. +#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) + +/// Launches a CUDA kernel appending to it all the information need to handle +/// device-side assertion failures. Checks that the launch was successful. +#define TORCH_DSA_KERNEL_LAUNCH( \ + kernel, blocks, threads, shared_mem, stream, ...) \ + do { \ + auto& launch_registry = \ + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref(); \ + kernel<<>>( \ + __VA_ARGS__, \ + launch_registry.get_uvm_assertions_ptr_for_current_device(), \ + launch_registry.insert( \ + __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + } while (0) + +namespace c10::cuda { + +/// In the event of a CUDA failure, formats a nice error message about that +/// failure and also checks for device-side assertion failures +C10_CUDA_API void c10_cuda_check_implementation( + const int32_t err, + const char* filename, + const char* function_name, + const uint32_t line_number, + const bool include_device_assertions); + +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAFunctions.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..a97b3d89401a64afc834bbb3c573a4f1b2f21c22 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAFunctions.h @@ -0,0 +1,131 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// This header provides C++ wrappers around commonly used CUDA API functions. +// The benefit of using C++ here is that we can raise an exception in the +// event of an error, rather than explicitly pass around error codes. This +// leads to more natural APIs. +// +// The naming convention used here matches the naming convention of torch.cuda + +#include +#include +#include +#include +#include +namespace c10::cuda { + +// NB: In the past, we were inconsistent about whether or not this reported +// an error if there were driver problems are not. Based on experience +// interacting with users, it seems that people basically ~never want this +// function to fail; it should just return zero if things are not working. +// Oblige them. +// It still might log a warning for user first time it's invoked +C10_CUDA_API DeviceIndex device_count() noexcept; + +// Version of device_count that throws is no devices are detected +C10_CUDA_API DeviceIndex device_count_ensure_non_zero(); + +C10_CUDA_API DeviceIndex current_device(); + +C10_CUDA_API void set_device(DeviceIndex device, const bool force = false); + +C10_CUDA_API void device_synchronize(); + +C10_CUDA_API void warn_or_error_on_sync(); + +// Raw CUDA device management functions +C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count); + +C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device); + +C10_CUDA_API cudaError_t +SetDevice(DeviceIndex device, const bool force = false); + +C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device); + +C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device); + +C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device); + +C10_CUDA_API void SetTargetDevice(); + +enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR }; + +// this is a holder for c10 global state (similar to at GlobalContext) +// currently it's used to store cuda synchronization warning state, +// but can be expanded to hold other related global state, e.g. to +// record stream usage +class WarningState { + public: + void set_sync_debug_mode(SyncDebugMode l) { + sync_debug_mode = l; + } + + SyncDebugMode get_sync_debug_mode() { + return sync_debug_mode; + } + + private: + SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED; +}; + +C10_CUDA_API __inline__ WarningState& warning_state() { + static WarningState warning_state_; + return warning_state_; +} +// the subsequent functions are defined in the header because for performance +// reasons we want them to be inline +C10_CUDA_API void __inline__ memcpy_and_sync( + void* dst, + const void* src, + int64_t nbytes, + cudaMemcpyKind kind, + cudaStream_t stream) { + if (C10_UNLIKELY( + warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { + warn_or_error_on_sync(); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_stream_synchronization( + c10::kCUDA, reinterpret_cast(stream)); + } +#if defined(USE_ROCM) && USE_ROCM + // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of + // hipMemcpyWithStream which is a synchronous call. Thus, we add a check + // here explicitly. + hipStreamCaptureStatus captureStatus; + C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr)); + if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) { + C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); + } else { + C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported); + } +#else + C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); + C10_CUDA_CHECK(cudaStreamSynchronize(stream)); +#endif +} + +C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) { + if (C10_UNLIKELY( + warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { + warn_or_error_on_sync(); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_stream_synchronization( + c10::kCUDA, reinterpret_cast(stream)); + } + C10_CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index); +C10_CUDA_API std::optional getDeviceIndexWithPrimaryContext(); + +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAGraphsC10Utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAGraphsC10Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..176c9290c3906815228faf0bdb502c50260eb1e9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAGraphsC10Utils.h @@ -0,0 +1,81 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +// CUDA Graphs utils used by c10 and aten. +// aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only. + +namespace c10::cuda { + +// RAII guard for "cudaStreamCaptureMode", a thread-local value +// that controls the error-checking strictness of a capture. +struct C10_CUDA_API CUDAStreamCaptureModeGuard { + CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired) + : strictness_(desired) { + C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_)); + } + CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete; + CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete; + CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) = + delete; + CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete; + ~CUDAStreamCaptureModeGuard() { + C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_)); + } + + private: + cudaStreamCaptureMode strictness_; +}; + +// Protects against enum cudaStreamCaptureStatus implementation changes. +// Some compilers seem not to like static_assert without the messages. +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0, + "unexpected int(cudaStreamCaptureStatusNone) value"); +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1, + "unexpected int(cudaStreamCaptureStatusActive) value"); +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2, + "unexpected int(cudaStreamCaptureStatusInvalidated) value"); + +enum class CaptureStatus : int { + None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone), + Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive), + Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) +}; + +inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { + switch (status) { + case CaptureStatus::None: + os << "cudaStreamCaptureStatusNone"; + break; + case CaptureStatus::Active: + os << "cudaStreamCaptureStatusActive"; + break; + case CaptureStatus::Invalidated: + os << "cudaStreamCaptureStatusInvalidated"; + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Unknown CUDA graph CaptureStatus", int(status)); + } + return os; +} + +// Use this version where you're sure a CUDA context exists already. +inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { + cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone}; + C10_CUDA_CHECK( + cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing)); + return CaptureStatus(is_capturing); +} + +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAGuard.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..6cf6ce4be26c07d3869fb4c7d7242fc220128fe8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAGuard.h @@ -0,0 +1,311 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10::cuda { + +// This code is kind of boilerplatey. See Note [Whither the DeviceGuard +// boilerplate] + +/// A variant of DeviceGuard that is specialized for CUDA. It accepts +/// integer indices (interpreting them as CUDA devices) and is a little +/// more efficient than DeviceGuard (it compiles to straight line +/// cudaSetDevice/cudaGetDevice calls); however, it can only be used +/// from code that links against CUDA directly. +struct CUDAGuard { + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit CUDAGuard() = delete; + + /// Set the current CUDA device to the passed device index. + explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {} + + /// Sets the current CUDA device to the passed device. Errors if the passed + /// device is not a CUDA device. + explicit CUDAGuard(Device device) : guard_(device) {} + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move is not allowed (there is no uninitialized state) + CUDAGuard(CUDAGuard&& other) = delete; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + ~CUDAGuard() = default; + + /// Sets the CUDA device to the given device. Errors if the given device + /// is not a CUDA device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the CUDA device to the given device. Errors if the given device + /// is not a CUDA device. (This method is provided for uniformity with + /// DeviceGuard). + void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the CUDA device to the given device index. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set upon construction of the guard + Device original_device() const { + return guard_.original_device(); + } + + /// Returns the last device that was set via `set_device`, if any, otherwise + /// the device passed during construction. + Device current_device() const { + return guard_.current_device(); + } + + private: + /// The guard for the current device. + c10::impl::InlineDeviceGuard guard_; +}; + +/// A variant of OptionalDeviceGuard that is specialized for CUDA. See +/// CUDAGuard for when you can use this. +struct OptionalCUDAGuard { + /// Create an uninitialized OptionalCUDAGuard. + explicit OptionalCUDAGuard() = default; + + /// Set the current CUDA device to the passed Device, if it is not nullopt. + explicit OptionalCUDAGuard(std::optional device_opt) + : guard_(device_opt) {} + + /// Set the current CUDA device to the passed device index, if it is not + /// nullopt + explicit OptionalCUDAGuard(std::optional device_index_opt) + : guard_(device_index_opt) {} + + // Copy is not allowed + OptionalCUDAGuard(const OptionalCUDAGuard&) = delete; + OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete; + ~OptionalCUDAGuard() = default; + + /// Sets the CUDA device to the given device, initializing the guard if it + /// is not already initialized. Errors if the given device is not a CUDA + /// device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the CUDA device to the given device, initializing the guard if it is + /// not already initialized. Errors if the given device is not a CUDA device. + /// (This method is provided for uniformity with OptionalDeviceGuard). + void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the CUDA device to the given device index, initializing the guard if + /// it is not already initialized. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set immediately prior to initialization of the + /// guard, or nullopt if the guard is uninitialized. + std::optional original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device, if the guard is initialized, + /// or nullopt if the guard is uninitialized. + std::optional current_device() const { + return guard_.current_device(); + } + + /// Restore the original CUDA device, resetting this guard to uninitialized + /// state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + +/// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard +/// for when you can use this. +struct CUDAStreamGuard { + /// No default constructor, see Note [Omitted default constructor from RAII] + explicit CUDAStreamGuard() = delete; + + /// Set the current CUDA device to the device associated with the passed + /// stream, and set the current CUDA stream on that device to the passed + /// stream. Errors if the Stream is not a CUDA stream. + explicit CUDAStreamGuard(Stream stream) : guard_(stream) {} + ~CUDAStreamGuard() = default; + + /// Copy is disallowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized + /// state, which is required for moves on types with nontrivial destructors. + CUDAStreamGuard(CUDAStreamGuard&& other) = delete; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Errors if the stream passed is not a CUDA stream. + /// + /// NOTE: this implementation may skip some stream/device setting if + /// it can prove that it is unnecessary. + /// + /// WARNING: reset_stream does NOT preserve previously set streams on + /// different devices. If you need to set streams on multiple devices + /// on CUDA, use CUDAMultiStreamGuard instead. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the CUDA stream that was set at the time the guard was + /// constructed. + CUDAStream original_stream() const { + return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream()); + } + + /// Returns the most recent CUDA stream that was set using this device guard, + /// either from construction, or via set_stream. + CUDAStream current_stream() const { + return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream()); + } + + /// Returns the most recent CUDA device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return guard_.current_device(); + } + + /// Returns the CUDA device that was set at the most recent reset_stream(), + /// or otherwise the device at construction time. + Device original_device() const { + return guard_.original_device(); + } + + private: + c10::impl::InlineStreamGuard guard_; +}; + +/// A variant of OptionalStreamGuard that is specialized for CUDA. See +/// CUDAGuard for when you can use this. +struct OptionalCUDAStreamGuard { + /// Create an uninitialized guard. + explicit OptionalCUDAStreamGuard() = default; + + /// Set the current CUDA device to the device associated with the passed + /// stream, and set the current CUDA stream on that device to the passed + /// stream. Errors if the Stream is not a CUDA stream. + explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {} + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream, + /// if the passed stream is not nullopt. + explicit OptionalCUDAStreamGuard(std::optional stream_opt) + : guard_(stream_opt) {} + + /// Copy is disallowed + OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete; + OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete; + ~OptionalCUDAStreamGuard() = default; + + /// Resets the currently set CUDA stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Initializes the guard if it was not previously initialized. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the CUDA stream that was set at the time the guard was most + /// recently initialized, or nullopt if the guard is uninitialized. + std::optional original_stream() const { + auto r = guard_.original_stream(); + if (r.has_value()) { + return CUDAStream(CUDAStream::UNCHECKED, r.value()); + } else { + return std::nullopt; + } + } + + /// Returns the most recent CUDA stream that was set using this stream guard, + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + std::optional current_stream() const { + auto r = guard_.current_stream(); + if (r.has_value()) { + return CUDAStream(CUDAStream::UNCHECKED, r.value()); + } else { + return std::nullopt; + } + } + + /// Restore the original CUDA device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalStreamGuard guard_; +}; + +/// A variant of MultiStreamGuard that is specialized for CUDA. +struct CUDAMultiStreamGuard { + explicit CUDAMultiStreamGuard(ArrayRef streams) + : guard_(unwrapStreams(streams)) {} + + /// Copy is disallowed + CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete; + CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete; + ~CUDAMultiStreamGuard() = default; + + private: + c10::impl::InlineMultiStreamGuard guard_; + + static std::vector unwrapStreams(ArrayRef cudaStreams) { + std::vector streams; + streams.reserve(cudaStreams.size()); + for (const CUDAStream& cudaStream : cudaStreams) { + streams.push_back(cudaStream); + } + return streams; + } +}; + +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMacros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMacros.h new file mode 100644 index 0000000000000000000000000000000000000000..93b371ce6ee854d074f6d47d0481c2a193e07d69 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMacros.h @@ -0,0 +1,56 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS + +// We have not yet modified the AMD HIP build to generate this file so +// we add an extra option to specifically ignore it. +#ifndef C10_CUDA_NO_CMAKE_CONFIGURE_FILE +#include +#endif // C10_CUDA_NO_CMAKE_CONFIGURE_FILE + +#endif + +// See c10/macros/Export.h for a detailed explanation of what the function +// of these macros are. We need one set of macros for every separate library +// we build. + +#ifdef _WIN32 +#if defined(C10_CUDA_BUILD_SHARED_LIBS) +#define C10_CUDA_EXPORT __declspec(dllexport) +#define C10_CUDA_IMPORT __declspec(dllimport) +#else +#define C10_CUDA_EXPORT +#define C10_CUDA_IMPORT +#endif +#else // _WIN32 +#if defined(__GNUC__) +#define C10_CUDA_EXPORT __attribute__((__visibility__("default"))) +#else // defined(__GNUC__) +#define C10_CUDA_EXPORT +#endif // defined(__GNUC__) +#define C10_CUDA_IMPORT C10_CUDA_EXPORT +#endif // _WIN32 + +// This one is being used by libc10_cuda.so +#ifdef C10_CUDA_BUILD_MAIN_LIB +#define C10_CUDA_API C10_CUDA_EXPORT +#else +#define C10_CUDA_API C10_CUDA_IMPORT +#endif + +/** + * The maximum number of GPUs that we recognizes. Increasing this beyond the + * initial limit of 16 broke Caffe2 testing, hence the ifdef guards. + * This value cannot be more than 128 because our DeviceIndex is a uint8_t. +o */ +#ifdef FBCODE_CAFFE2 +// fbcode depends on this value being 16 +#define C10_COMPILE_TIME_MAX_GPUS 16 +#else +#define C10_COMPILE_TIME_MAX_GPUS 120 +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMathCompat.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMathCompat.h new file mode 100644 index 0000000000000000000000000000000000000000..ec08cde0c1b71c9a0c8dd586e4fa7f6760e230f8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMathCompat.h @@ -0,0 +1,157 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +/* This file defines math functions compatible across different gpu + * platforms (currently CUDA and HIP). + */ +#if defined(__CUDACC__) || defined(__HIPCC__) + +#include +#include + +#ifdef __HIPCC__ +#define __MATH_FUNCTIONS_DECL__ inline C10_DEVICE +#else /* __HIPCC__ */ +#ifdef __CUDACC_RTC__ +#define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE +#else /* __CUDACC_RTC__ */ +#define __MATH_FUNCTIONS_DECL__ inline C10_HOST_DEVICE +#endif /* __CUDACC_RTC__ */ +#endif /* __HIPCC__ */ + +namespace c10::cuda::compat { + +__MATH_FUNCTIONS_DECL__ float abs(float x) { + return ::fabsf(x); +} +__MATH_FUNCTIONS_DECL__ double abs(double x) { + return ::fabs(x); +} + +__MATH_FUNCTIONS_DECL__ float exp(float x) { + return ::expf(x); +} +__MATH_FUNCTIONS_DECL__ double exp(double x) { + return ::exp(x); +} + +__MATH_FUNCTIONS_DECL__ float ceil(float x) { + return ::ceilf(x); +} +__MATH_FUNCTIONS_DECL__ double ceil(double x) { + return ::ceil(x); +} + +__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + return ::copysignf(x, y); +#else + // std::copysign gets ICE/Segfaults with gcc 7.5/8 on arm64 + // (e.g. Jetson), see PyTorch PR #51834 + // This host function needs to be here for the compiler but is never used + TORCH_INTERNAL_ASSERT( + false, "CUDAMathCompat copysign should not run on the CPU"); +#endif +} +__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + return ::copysign(x, y); +#else + // see above + TORCH_INTERNAL_ASSERT( + false, "CUDAMathCompat copysign should not run on the CPU"); +#endif +} + +__MATH_FUNCTIONS_DECL__ float floor(float x) { + return ::floorf(x); +} +__MATH_FUNCTIONS_DECL__ double floor(double x) { + return ::floor(x); +} + +__MATH_FUNCTIONS_DECL__ float log(float x) { + return ::logf(x); +} +__MATH_FUNCTIONS_DECL__ double log(double x) { + return ::log(x); +} + +__MATH_FUNCTIONS_DECL__ float log1p(float x) { + return ::log1pf(x); +} + +__MATH_FUNCTIONS_DECL__ double log1p(double x) { + return ::log1p(x); +} + +__MATH_FUNCTIONS_DECL__ float max(float x, float y) { + return ::fmaxf(x, y); +} +__MATH_FUNCTIONS_DECL__ double max(double x, double y) { + return ::fmax(x, y); +} + +__MATH_FUNCTIONS_DECL__ float min(float x, float y) { + return ::fminf(x, y); +} +__MATH_FUNCTIONS_DECL__ double min(double x, double y) { + return ::fmin(x, y); +} + +__MATH_FUNCTIONS_DECL__ float pow(float x, float y) { + return ::powf(x, y); +} +__MATH_FUNCTIONS_DECL__ double pow(double x, double y) { + return ::pow(x, y); +} + +__MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) { + return ::sincosf(x, sptr, cptr); +} +__MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) { + return ::sincos(x, sptr, cptr); +} + +__MATH_FUNCTIONS_DECL__ float sqrt(float x) { + return ::sqrtf(x); +} +__MATH_FUNCTIONS_DECL__ double sqrt(double x) { + return ::sqrt(x); +} + +__MATH_FUNCTIONS_DECL__ float rsqrt(float x) { + return ::rsqrtf(x); +} +__MATH_FUNCTIONS_DECL__ double rsqrt(double x) { + return ::rsqrt(x); +} + +__MATH_FUNCTIONS_DECL__ float tan(float x) { + return ::tanf(x); +} +__MATH_FUNCTIONS_DECL__ double tan(double x) { + return ::tan(x); +} + +__MATH_FUNCTIONS_DECL__ float tanh(float x) { + return ::tanhf(x); +} +__MATH_FUNCTIONS_DECL__ double tanh(double x) { + return ::tanh(x); +} + +__MATH_FUNCTIONS_DECL__ float normcdf(float x) { + return ::normcdff(x); +} +__MATH_FUNCTIONS_DECL__ double normcdf(double x) { + return ::normcdf(x); +} + +} // namespace c10::cuda::compat + +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMiscFunctions.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMiscFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..c44105fa61281b2d06f02524b789d7c7554374f9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAMiscFunctions.h @@ -0,0 +1,20 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// this file is to avoid circular dependency between CUDAFunctions.h and +// CUDAExceptions.h + +#include +#include + +#include +#include + +namespace c10::cuda { +C10_CUDA_API std::string get_cuda_error_help(cudaError_t /*error*/) noexcept; +C10_CUDA_API const char* get_cuda_check_suffix() noexcept; +C10_CUDA_API std::mutex* getFreeMutex(); +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAStream.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAStream.h new file mode 100644 index 0000000000000000000000000000000000000000..c0e616f584c5a41e40e75586c4e3d3ae8b381feb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/CUDAStream.h @@ -0,0 +1,273 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include +#include + +/* + * Stream pool note. + * + * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams + * are backed by cuStreams, but they use several pools to minimize the costs + * associated with creating, retaining, and destroying cuStreams. + * + * There are three pools per device, and a device's pools are lazily created. + * + * The first pool contains only the default stream. When the default stream + * is requested it's returned. + * + * The second pool is the "low priority" or "default priority" streams. In + * HIP builds there is no distinction between streams in this pool and streams + * in the third pool (below). There are 32 of these streams per device, and + * when a stream is requested one of these streams is returned round-robin. + * That is, the first stream requested is at index 0, the second at index 1... + * to index 31, then index 0 again. + * + * This means that if 33 low priority streams are requested, the first and + * last streams requested are actually the same stream (under the covers) + * and kernels enqueued on them cannot run concurrently. + * + * The third pool is the "high priority" streams. The third pool acts like + * the second pool except the streams are created with a higher priority. + * + * These pools suggest that stream users should prefer many short-lived streams, + * as the cost of acquiring and releasing streams is effectively zero. If + * many longer-lived streams are required in performance critical scenarios + * then the functionality here may need to be extended to allow, for example, + * "reserving" a subset of the pool so that other streams do not accidentally + * overlap the performance critical streams. + * + * Note: although the notion of "current stream for device" is thread local + * (every OS thread has a separate current stream, as one might expect), + * the stream pool is global across all threads; stream 0 is always stream 0 + * no matter which thread you use it on. Multiple threads can synchronize + * on the same stream. Although the CUDA documentation is not very clear + * on the matter, streams are thread safe; e.g., it is safe to enqueue + * a kernel on the same stream from two different threads. + */ + +namespace c10::cuda { + +static constexpr int max_compile_time_stream_priorities = 4; + +// Value object representing a CUDA stream. This is just a wrapper +// around c10::Stream, but it comes with a little extra CUDA-specific +// functionality (conversion to cudaStream_t), and a guarantee that +// the wrapped c10::Stream really is a CUDA stream. +class C10_CUDA_API CUDAStream { + public: + enum Unchecked { UNCHECKED }; + + /// Construct a CUDAStream from a Stream. This construction is checked, + /// and will raise an error if the Stream is not, in fact, a CUDA stream. + explicit CUDAStream(Stream stream) : stream_(stream) { + TORCH_CHECK(stream_.device_type() == DeviceType::CUDA); + } + + /// Construct a CUDAStream from a Stream with no error checking. + /// This constructor uses the "named" constructor idiom, and can + /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream) + explicit CUDAStream(Unchecked /*unused*/, Stream stream) : stream_(stream) {} + + bool operator==(const CUDAStream& other) const noexcept { + return unwrap() == other.unwrap(); + } + + bool operator!=(const CUDAStream& other) const noexcept { + return unwrap() != other.unwrap(); + } + + /// Implicit conversion to cudaStream_t. + operator cudaStream_t() const { + return stream(); + } + + /// Implicit conversion to Stream (a.k.a., forget that the stream is a + /// CUDA stream). + operator Stream() const { + return unwrap(); + } + + /// Used to avoid baking in device type explicitly to Python-side API. + DeviceType device_type() const { + return DeviceType::CUDA; + } + + /// Get the CUDA device index that this stream is associated with. + DeviceIndex device_index() const { + return stream_.device_index(); + } + + /// Get the full Device that this stream is associated with. The Device + /// is guaranteed to be a CUDA device. + Device device() const { + return Device(DeviceType::CUDA, device_index()); + } + + /// Return the stream ID corresponding to this particular stream. + StreamId id() const { + return stream_.id(); + } + + bool query() const { + DeviceGuard guard{stream_.device()}; + cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream())); + + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; + } + + void synchronize() const { + DeviceGuard guard{stream_.device()}; + c10::cuda::stream_synchronize(stream()); + } + + int priority() const { + DeviceGuard guard{stream_.device()}; + int priority = 0; + C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); + return priority; + } + + /// Explicit conversion to cudaStream_t. + cudaStream_t stream() const; + + /// Explicit conversion to Stream. + Stream unwrap() const { + return stream_; + } + + /// Reversibly pack a CUDAStream into a struct representation. + /// Previously the stream's data was packed into a single int64_t, + /// as it was assumed the fields would not require more than + /// 64 bits of storage in total. + /// See https://github.com/pytorch/pytorch/issues/75854 + /// for more information regarding newer platforms that may violate + /// this assumption. + /// + /// The CUDAStream can be unpacked using unpack(). + struct c10::StreamData3 pack3() const { + return stream_.pack3(); + } + + // Unpack a CUDAStream from the 3 fields generated by pack(). + static CUDAStream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + return CUDAStream(Stream::unpack3(stream_id, device_index, device_type)); + } + + static std::tuple priority_range() { + // Note: this returns the range of priority **supported by PyTorch**, not + // the range of priority **supported by CUDA**. The former is a subset of + // the latter. + int least_priority = 0, greatest_priority = 0; + C10_CUDA_CHECK( + cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); +#ifdef USE_ROCM + // See Note [HIP stream priorities] + TORCH_INTERNAL_ASSERT( + least_priority == 1, "Unexpected HIP stream priority range"); + least_priority = 0; +#else + TORCH_INTERNAL_ASSERT( + least_priority == 0, "Unexpected CUDA stream priority range"); +#endif + TORCH_INTERNAL_ASSERT( + greatest_priority <= -1, "Unexpected CUDA stream priority range"); + greatest_priority = std::max( + -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority); + return std::make_tuple(least_priority, greatest_priority); + } + + // Deleted for now; use CUDAEvent::block instead + // void synchronize_with(const CUDAEvent& event) const; + + private: + Stream stream_; +}; + +/** + * Get a new stream from the CUDA stream pool. You can think of this + * as "creating" a new stream, but no such creation actually happens; + * instead, streams are preallocated from the pool and returned in a + * round-robin fashion. + * + * You can request a stream from the high priority pool by setting + * isHighPriority to true, or a stream for a specific device by setting device + * (defaulting to the current CUDA stream.) + */ +C10_API CUDAStream +getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); +// no default priority to disambiguate overloads +C10_API CUDAStream +getStreamFromPool(const int priority, DeviceIndex device = -1); + +/** + * Get a CUDAStream from a externally allocated one. + * + * This is mainly for interoperability with different libraries where we + * want to operate on a non-torch allocated stream for data exchange or similar + * purposes + */ +C10_API CUDAStream +getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index); + +/** + * Get the default CUDA stream, for the passed CUDA device, or for the + * current device if no device index is passed. The default stream is + * where most computation occurs when you aren't explicitly using + * streams. + */ +C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); + +/** + * Get the current CUDA stream, for the passed CUDA device, or for the + * current device if no device index is passed. The current CUDA stream + * will usually be the default CUDA stream for the device, but it may + * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard' + * or 'CUDAStreamGuard'. + */ +C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); + +/** + * Set the current stream on the device of the passed in stream to be + * the passed in stream. Yes, you read that right: this function + * has *nothing* to do with the current device: it toggles the current + * stream of the device of the passed stream. + * + * Confused? Avoid using this function; prefer using 'CUDAStreamGuard' instead + * (which will switch both your current device and current stream in the way you + * expect, and reset it back to its original state afterwards). + */ +C10_API void setCurrentCUDAStream(CUDAStream stream); + +C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s); + +} // namespace c10::cuda + +namespace std { +template <> +struct hash { + size_t operator()(c10::cuda::CUDAStream s) const noexcept { + return std::hash{}(s.unwrap()); + } +}; +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/driver_api.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/driver_api.h new file mode 100644 index 0000000000000000000000000000000000000000..49a5a131d4888f5f8f422bc07b74065db9315397 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/driver_api.h @@ -0,0 +1,124 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#define NVML_NO_UNVERSIONED_FUNC_DEFS +#include + +#include + +#define C10_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + CUresult get_error_str_err [[maybe_unused]] = \ + c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + TORCH_CHECK(false, "CUDA driver error: unknown error"); \ + } else { \ + TORCH_CHECK(false, "CUDA driver error: ", err_str); \ + } \ + } \ + } while (0) + +#define C10_CUDA_DRIVER_CHECK_GOTO(EXPR, NEXT) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + CUresult get_error_str_err [[maybe_unused]] = \ + c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + TORCH_WARN("CUDA driver error: unknown error"); \ + } else { \ + TORCH_WARN("CUDA driver error: ", err_str); \ + } \ + goto NEXT; \ + } \ + } while (0) + +// The integer in the second column specifies the requested CUDA Driver API +// version. The dynamic loader will accept a driver with a newer version, but it +// ensures that the requested symbol exists in *at least* the specified version +// or earlier. + +// Keep these requested versions as low as possible to maximize compatibility +// across different driver versions. + +// Why do we pin to an older version instead of using the latest? +// If a user installs a newer driver, blindly resolving the symbol may bind to a +// newer version of the function with different behavior, potentially breaking +// PyTorch. + +#define C10_LIBCUDA_DRIVER_API_REQUIRED(_) \ + _(cuDeviceGetAttribute, 12000) \ + _(cuMemAddressReserve, 12000) \ + _(cuMemRelease, 12000) \ + _(cuMemMap, 12000) \ + _(cuMemAddressFree, 12000) \ + _(cuMemSetAccess, 12000) \ + _(cuMemUnmap, 12000) \ + _(cuMemCreate, 12000) \ + _(cuMemGetAllocationGranularity, 12000) \ + _(cuMemExportToShareableHandle, 12000) \ + _(cuMemImportFromShareableHandle, 12000) \ + _(cuMemsetD32Async, 12000) \ + _(cuStreamWriteValue32, 12000) \ + _(cuGetErrorString, 12000) + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) +#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ + _(cuCtxFromGreenCtx, 12080) \ + _(cuCtxGetCurrent, 12080) \ + _(cuCtxPopCurrent, 12080) \ + _(cuCtxPushCurrent, 12080) \ + _(cuCtxSetCurrent, 12080) \ + _(cuGreenCtxCreate, 12080) \ + _(cuGreenCtxDestroy, 12080) \ + _(cuDevSmResourceSplitByCount, 12080) \ + _(cuDeviceGet, 12080) \ + _(cuDeviceGetDevResource, 12080) \ + _(cuDevResourceGenerateDesc, 12080) \ + _(cuMulticastAddDevice, 12030) \ + _(cuMulticastBindMem, 12030) \ + _(cuMulticastCreate, 12030) \ + _(cuMulticastUnbind, 12030) +#else +#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) +#endif + +#define C10_NVML_DRIVER_API(_) \ + _(nvmlInit_v2) \ + _(nvmlDeviceGetHandleByPciBusId_v2) \ + _(nvmlDeviceGetNvLinkRemoteDeviceType) \ + _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \ + _(nvmlDeviceGetComputeRunningProcesses) \ + _(nvmlSystemGetCudaDriverVersion_v2) + +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12040) +#define C10_NVML_DRIVER_API_OPTIONAL(_) _(nvmlDeviceGetGpuFabricInfoV) +#else +#define C10_NVML_DRIVER_API_OPTIONAL(_) +#endif + +namespace c10::cuda { + +struct DriverAPI { +#define CREATE_MEMBER_VERSIONED(name, version) decltype(&name) name##_; +#define CREATE_MEMBER(name) decltype(&name) name##_; + C10_LIBCUDA_DRIVER_API_REQUIRED(CREATE_MEMBER_VERSIONED) + C10_LIBCUDA_DRIVER_API_OPTIONAL(CREATE_MEMBER_VERSIONED) + C10_NVML_DRIVER_API(CREATE_MEMBER) + C10_NVML_DRIVER_API_OPTIONAL(CREATE_MEMBER) +#undef CREATE_MEMBER_VERSIONED +#undef CREATE_MEMBER + + static DriverAPI* get(); + static void* get_nvml_handle(); +}; + +} // namespace c10::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/CUDAGuardImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/CUDAGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..24cb643a0599072f52eb1188bf53fc236368e957 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/CUDAGuardImpl.h @@ -0,0 +1,270 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10::cuda::impl { + +struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = DeviceType::CUDA; + + CUDAGuardImpl() = default; + explicit CUDAGuardImpl(DeviceType t) { + TORCH_CHECK( + t == DeviceType::CUDA, + "CUDAGuardImpl initialized with non-CUDA DeviceType: ", + t); + } + DeviceType type() const override { + return DeviceType::CUDA; + } + Device exchangeDevice(Device d) const override { + TORCH_CHECK(d.is_cuda(), "Expected a CUDA device, but got ", d); + auto old_device_index = c10::cuda::ExchangeDevice(d.index()); + return Device(DeviceType::CUDA, old_device_index); + } + Device getDevice() const override { + DeviceIndex device = 0; + C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + return Device(DeviceType::CUDA, device); + } + std::optional uncheckedGetDevice() const noexcept { + DeviceIndex device{-1}; + const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK_WARN(err); + if (err != cudaSuccess) { + return std::nullopt; + } + return Device(DeviceType::CUDA, device); + } + void setDevice(Device d) const override { + TORCH_CHECK(d.is_cuda(), "Expected a CUDA device, but got ", d); + C10_CUDA_CHECK(c10::cuda::SetDevice(d.index())); + } + void uncheckedSetDevice(Device d) const noexcept override { + C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index())); + } + Stream getStream(Device d) const override { + return getCurrentCUDAStream(d.index()).unwrap(); + } + Stream getDefaultStream(Device d) const override { + return getDefaultCUDAStream(d.index()); + } + Stream getNewStream(Device d, int priority = 0) const override { + return getStreamFromPool(priority, d.index()); + } + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) + const override { + return getStreamFromPool(isHighPriority, d.index()); + } + // NB: These do NOT set the current device + Stream exchangeStream(Stream s) const override { + CUDAStream cs(s); + auto old_stream = getCurrentCUDAStream(s.device().index()); + setCurrentCUDAStream(cs); + return old_stream.unwrap(); + } + DeviceIndex deviceCount() const noexcept override { + return device_count(); + } + + // Event-related functions + void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const { + // Maps PyTorch's Event::Flag to CUDA flag + auto cuda_flag = cudaEventDefault; + switch (flag) { + case EventFlag::PYTORCH_DEFAULT: + cuda_flag = cudaEventDisableTiming; + break; + case EventFlag::BACKEND_DEFAULT: + cuda_flag = cudaEventDefault; + break; + default: + TORCH_CHECK(false, "CUDA event received unknown flag"); + } + + C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation( + c10::kCUDA, reinterpret_cast(cuda_event)); + } + } + + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { + if (!event) + return; + auto cuda_event = static_cast(event); + DeviceIndex orig_device{-1}; + C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device)); + C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + c10::kCUDA, reinterpret_cast(cuda_event)); + } + C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); + C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device)); + } + + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK( + device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); + + cudaEvent_t cuda_event = static_cast(*event); + CUDAStream cuda_stream{stream}; + + // Moves to stream's device to record + const auto orig_device = getDevice(); + setDevice(stream.device()); + + // Creates the event (lazily) + if (!cuda_event) + createEvent(&cuda_event, flag); + C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); + // Makes the void* point to the (possibly just allocated) CUDA event + *event = cuda_event; + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + c10::kCUDA, + reinterpret_cast(cuda_event), + reinterpret_cast(cuda_stream.stream())); + } + + // Resets device + setDevice(orig_device); + } + + void block(void* event, const Stream& stream) const override { + if (!event) + return; + cudaEvent_t cuda_event = static_cast(event); + CUDAStream cuda_stream{stream}; + const auto orig_device = getDevice(); + setDevice(stream.device()); + C10_CUDA_CHECK(cudaStreamWaitEvent( + cuda_stream, + cuda_event, + /*flags (must be zero)=*/0)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + c10::kCUDA, + reinterpret_cast(cuda_event), + reinterpret_cast(cuda_stream.stream())); + } + setDevice(orig_device); + } + + // May be called from any device + bool queryEvent(void* event) const override { + if (!event) + return true; + cudaEvent_t cuda_event = static_cast(event); + // Note: cudaEventQuery can be safely called from any device + const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event)); + if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + return (err == cudaSuccess); + } + + // Stream-related functions + bool queryStream(const Stream& stream) const override { + CUDAStream cuda_stream{stream}; + return cuda_stream.query(); + } + + void synchronizeStream(const Stream& stream) const override { + CUDAStream cuda_stream{stream}; + cuda_stream.synchronize(); + } + + void synchronizeEvent(void* event) const override { + if (!event) + return; + cudaEvent_t cuda_event = static_cast(event); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization( + c10::kCUDA, reinterpret_cast(cuda_event)); + } + // Note: cudaEventSynchronize can be safely called from any device + C10_CUDA_CHECK(cudaEventSynchronize(cuda_event)); + } + + // Note: synchronizeDevice can be safely called from any device + void synchronizeDevice(const c10::DeviceIndex device_index) const override { + DeviceIndex orig_device{-1}; + C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device)); + C10_CUDA_CHECK(c10::cuda::SetDevice(device_index)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_device_synchronization(c10::kCUDA); + } + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device)); + } + + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { + CUDAStream cuda_stream{stream}; + CUDACachingAllocator::recordStream(data_ptr, cuda_stream); + } + + double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) + const override { + TORCH_CHECK( + event1 && event2, + "Both events must be recorded before calculating elapsed time."); + // Even though cudaEventElapsedTime can be safely called from any device, if + // the current device is not initialized, it will create a new cuda context, + // which will consume a lot of memory. + DeviceIndex orig_device{-1}; + C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device)); + C10_CUDA_CHECK(c10::cuda::SetDevice(device_index)); + cudaEvent_t cuda_event1 = static_cast(event1); + cudaEvent_t cuda_event2 = static_cast(event2); + float time_ms = 0; + // raise cudaErrorNotReady if either event is recorded but not yet completed + C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, cuda_event1, cuda_event2)); + C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device)); + return static_cast(time_ms); + } +}; + +} // namespace c10::cuda::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/CUDATest.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/CUDATest.h new file mode 100644 index 0000000000000000000000000000000000000000..3edcfe6d88a72a94120bf95d82a6bbc0a0798500 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/CUDATest.h @@ -0,0 +1,14 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10::cuda::impl { + +C10_CUDA_API int c10_cuda_test(); + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/cuda_cmake_macros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/cuda_cmake_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..a2fb43f54676972b1df12b2be146786465a1b403 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/cuda/impl/cuda_cmake_macros.h @@ -0,0 +1,11 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// Automatically generated header file for the C10 CUDA library. Do not +// include this file directly. Instead, include c10/cuda/CUDAMacros.h + +#define C10_CUDA_BUILD_SHARED_LIBS + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/Export.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/Export.h new file mode 100644 index 0000000000000000000000000000000000000000..dfc4378c482c621ce05179900c719510e59ee8d0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/Export.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/Macros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/Macros.h new file mode 100644 index 0000000000000000000000000000000000000000..02fdbd4df99eaed11dfdc5dc190378156ea30177 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/Macros.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/cmake_macros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/cmake_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..5d89f61f37a9db44fc7bbe5df20ce372e37dff4c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/macros/cmake_macros.h @@ -0,0 +1,10 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// This file exists for backwards compatibility and has been moved to +// torch/headeronly/macros/cmake_macros.h.in. No end user library should be +// including this file directly anyway (cuz they should be including +// Macros.h instead). +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/atomic.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/atomic.h new file mode 100644 index 0000000000000000000000000000000000000000..4bec87d32d3efa5badc79d2b85d2cb018fe9c9a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/atomic.h @@ -0,0 +1,182 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +namespace c10 { +namespace metal { + +// Atomic operations helper +template +struct AtomicType {}; +template +using AtomicType_t = typename AtomicType::type; + +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, float value) { + ::metal::atomic_fetch_add_explicit( + data + offset, value, ::metal::memory_order_relaxed); + } +}; + +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, int value) { + ::metal::atomic_fetch_add_explicit( + data + offset, value, ::metal::memory_order_relaxed); + } +}; + +// As of Metal3.2 atomic operations are not supported on half-precision floats, +// so they must be simulated Using atomic compare and exchange over 32-bit +// atomic type +template +static inline void atomic_add_helper( + device ::metal::atomic* data, + long offset, + T value) { + constexpr auto elem_per_enum = sizeof(uint) / sizeof(T); + auto ptr = data + (offset / elem_per_enum); + auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed); + union { + uint i; + T t[elem_per_enum]; + } val; + do { + val.i = old; + val.t[offset & (elem_per_enum - 1)] += value; + } while (!::metal::atomic_compare_exchange_weak_explicit( + ptr, + &old, + val.i, + ::metal::memory_order_relaxed, + ::metal::memory_order_relaxed)); +} + +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, half value) { + atomic_add_helper(data, offset, value); + } +}; + +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, short value) { + atomic_add_helper(data, offset, value); + } +}; + +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, char value) { + atomic_add_helper(data, offset, value); + } +}; + +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, char value) { + atomic_add_helper(data, offset, value); + } +}; + +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, bfloat value) { + atomic_add_helper(data, offset, value); + } +}; + +// Metal supports atomic_store_explicit for bools, but +// sizeof(::metal::atomic_bool) is 4 Therefore it could not be used to +// atomically modify unaligned memory, so fall back to compare and exchange +// trick As accumulation over booleans are just or operation, do nothing if +// value is false +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, bool value) { + if (!value) { + return; + } + auto ptr = data + (offset >> 2); + auto old = + ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed); + union { + uint i; + bool t[4]; + } val; + do { + val.i = old; + val.t[offset & 3] = true; + } while (!::metal::atomic_compare_exchange_weak_explicit( + ptr, + &old, + val.i, + ::metal::memory_order_relaxed, + ::metal::memory_order_relaxed)); + } +}; + +// ComplexHalf atomic op +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, half2 value) { + auto ptr = data + offset; + auto old = + ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed); + while (!::metal::atomic_compare_exchange_weak_explicit( + ptr, + &old, + as_type(as_type(old) + value), + ::metal::memory_order_relaxed, + ::metal::memory_order_relaxed)) + ; + } +}; + +// There are no atomic 64-bit add in Metal yet, but templates below implements a +// consistent add I.e. if multiple threads are modify the same 64-bit value, +// results stored at the address will eventually be equal to its original value +// plus sum of all operands +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, long value) { + const auto value_bits = as_type(value); + const uint low = static_cast(value_bits); + uint high = static_cast(value_bits >> 32); + auto ptr = data + (offset << 1); + auto old_low = + atomic_fetch_add_explicit(ptr, low, ::metal::memory_order_relaxed); + high += (old_low + low < old_low) ? 1 : 0; + atomic_fetch_add_explicit(ptr + 1, high, ::metal::memory_order_relaxed); + } +}; + +// ComplexFloat atomic op, which again is not really atomic, but eventually +// consistent +template <> +struct AtomicType { + using type = ::metal::atomic; + static inline void atomic_add(device type* data, long offset, float2 value) { + auto ptr = data + (offset << 1); + atomic_fetch_add_explicit(ptr + 0, value.x, ::metal::memory_order_relaxed); + atomic_fetch_add_explicit(ptr + 1, value.y, ::metal::memory_order_relaxed); + } +}; + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/common.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/common.h new file mode 100644 index 0000000000000000000000000000000000000000..c508bbd55afa7077644bc5ff722ccbc46056e99c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/common.h @@ -0,0 +1,50 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// Set of global constants that could be shareable between CPU and Metal code + +#ifdef __METAL__ +#include +#define C10_METAL_CONSTEXPR constant constexpr +#else +#include +#define C10_METAL_CONSTEXPR constexpr +#endif + +#define C10_METAL_ALL_TYPES_FUNCTOR(_) \ + _(Byte, 0) \ + _(Char, 1) \ + _(Short, 2) \ + _(Int, 3) \ + _(Long, 4) \ + _(Half, 5) \ + _(Float, 6) \ + _(ComplexHalf, 8) \ + _(ComplexFloat, 9) \ + _(Bool, 11) \ + _(BFloat16, 15) + +namespace c10 { +namespace metal { +C10_METAL_CONSTEXPR unsigned max_ndim = 16; +C10_METAL_CONSTEXPR unsigned simdgroup_size = 32; + +#ifdef __METAL__ +template +using array = ::metal::array; +#else +template +using array = std::array; +#endif + +enum class ScalarType { +#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n, + C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_) +#undef _DEFINE_ENUM_VAL_ +}; + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/error.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/error.h new file mode 100644 index 0000000000000000000000000000000000000000..25786e69bb6d9c37d69ce603aed53c8cb04a4a10 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/error.h @@ -0,0 +1,116 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace c10 { +namespace metal { +C10_METAL_CONSTEXPR unsigned error_message_count = 30; +struct ErrorMessage { + char file[128]; + char func[128]; + char message[250]; + unsigned int line; +}; + +struct ErrorMessages { +#ifdef __METAL__ + ::metal::atomic count; +#else + unsigned int count; +#endif + ErrorMessage msg[error_message_count]; +}; + +#ifdef __METAL__ +namespace detail { +static uint strncpy(device char* dst, constant const char* src, unsigned len) { + uint i = 0; + while (src[i] != 0 && i < len - 1) { + dst[i] = src[i]; + i++; + } + dst[i] = 0; + return i; +} + +inline uint print_arg( + device char* ptr, + unsigned len, + constant const char* arg) { + return strncpy(ptr, arg, len); +} + +// Returns number length as string in base10 +static inline uint base10_length(long num) { + uint rc = 1; + if (num < 0) { + num = -num; + rc += 1; + } + while (num > 9) { + num /= 10; + rc++; + } + return rc; +} + +// Converts signed integer to string +inline uint print_arg(device char* ptr, unsigned len, long arg) { + const auto arg_len = base10_length(arg); + if (arg_len >= len) + return 0; + if (arg < 0) { + ptr[0] = '-'; + arg = -arg; + } + uint idx = 1; + do { + ptr[arg_len - idx] = '0' + (arg % 10); + arg /= 10; + idx++; + } while (arg > 0); + ptr[arg_len] = 0; + return arg_len; +} + +template +inline void print_args(device char* ptr, unsigned len, T arg) { + print_arg(ptr, len, arg); +} + +template +inline void print_args(device char* ptr, unsigned len, T arg, Args... args) { + const auto rc = print_arg(ptr, len, arg); + print_args(ptr + rc, len - rc, args...); +} + +} // namespace detail + +template +static void report_error( + device ErrorMessages* msgs, + constant const char* file, + int line, + constant const char* func, + Args... args) { + const auto idx = + atomic_fetch_add_explicit(&msgs->count, 1, ::metal::memory_order_relaxed); + if (idx >= error_message_count) { + return; + } + device auto* msg = &msgs->msg[idx]; + detail::strncpy(msg->file, file, 128); + detail::strncpy(msg->func, func, 128); + detail::print_args(msg->message, 250, args...); + msg->line = line; +} + +#define TORCH_REPORT_ERROR(buf, ...) \ + ::c10::metal::report_error(buf, __FILE__, __LINE__, __func__, __VA_ARGS__) +#endif +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/expm1f.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/expm1f.h new file mode 100644 index 0000000000000000000000000000000000000000..18061b711232ddc8053f6672b23814fee5023926 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/expm1f.h @@ -0,0 +1,102 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copy-and-pasted from: +// https://github.com/ml-explore/mlx/blob/99c33d011d63174f50cea37c3eede002958be6d3/mlx/backend/metal/kernels/expm1f.h + +#pragma once + +#include + +// Original license copied below: +// Copyright (c) 2015-2023 Norbert Juffa +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +namespace c10 { +namespace metal { + +/* Compute exponential base e minus 1. Maximum ulp error = 0.997458 + + i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1. + Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5). + With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy, + when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r. + + NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2) +*/ +inline float expm1f_scaled_unchecked(float a, float b) { + float f, j, r, s, t, u, v, x, y; + int i; + + // exp(a) = 2**i * exp(f); i = rintf (a / log(2)) + j = ::metal::fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23 + j = j - 12582912.0f; // 0x1.8p23 + i = (int)j; + f = ::metal::fma(j, -6.93145752e-1f, a); + + // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] + s = f * f; + if (a == 0.0f) + s = a; // ensure -0 is passed through + // err = 0.997458 ulp1 = 11081805 + r = 1.97350979e-4f; // 0x1.9de000p-13 + r = ::metal::fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10 + r = ::metal::fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7 + r = ::metal::fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5 + r = ::metal::fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3 + r = ::metal::fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2 + u = (j == 1) ? (f + 0.5f) : f; + v = ::metal::fma(r, s, u); + s = 0.5f * b; + t = ::metal::ldexp(s, i); + y = t - s; + x = (t - y) - s; // double-float canonicalization of difference + r = ::metal::fma(v, t, x) + y; + r = r + r; + if (j == 0) + r = v; + if (j == 1) + r = v + v; + return r; +} + +/* Compute exponential base e minus 1. max ulp err = 0.99746 */ +inline float expm1f(float a) { + float r; + + r = expm1f_scaled_unchecked(a, 1.0f); + /* handle severe overflow and underflow */ + if (::metal::abs(a - 1.0f) > 88.0f) { + r = ::metal::pow(2, a); + r = ::metal::fma(r, r, -1.0f); + } + return r; +} + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/igamma.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/igamma.h new file mode 100644 index 0000000000000000000000000000000000000000..4fb235e226ad27e7bb94b76a02172df86ce4c17f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/igamma.h @@ -0,0 +1,749 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +using namespace c10::metal; +using namespace metal; + +namespace c10 { +namespace metal { + +template +inline float log_gamma(const T); + +inline float expm1f(float a); + +template +float erfc(T x); + +} // namespace metal +} // namespace c10 + +namespace { + +template +inline float lgamma(const T a) { + return log_gamma(a); +} + +inline float expm1(float a) { + return expm1f(a); +} + +// NOTE: The following code was ported directly from the CUDA implementation in +// `aten/src/ATen/native/cuda/IGammaKernel.cu` + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +// regularized lower & upper incomplete gamma +template +scalar_t ratevl( + scalar_t x, + const scalar_t num[], + int64_t M, + const scalar_t denom[], + int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + using accscalar_t = opmath_t; + int64_t i, dir; + accscalar_t y, num_ans, denom_ans; + accscalar_t absx = ::fabs(x); + thread const accscalar_t* p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return ::pow(x, static_cast(i)) * num_ans / denom_ans; + } else { + return num_ans / denom_ans; + } +} + +template +scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + using accscalar_t = opmath_t; + + const accscalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859}; + const accscalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0}; + return ratevl( + static_cast(x), + lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / + sizeof(lanczos_sum_expg_scaled_num[0]) - + 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / + sizeof(lanczos_sum_expg_scaled_denom[0]) - + 1); +} + +template +scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + using accscalar_t = opmath_t; + accscalar_t ax, fac, res, num, numfac; + const accscalar_t MAXLOG = 88.72283905206835; + const accscalar_t EXP1 = 2.718281828459045; + const accscalar_t lanczos_g = 6.024680040776729583740234375; + + if (::fabs(a - x) > 0.4 * ::fabs(a)) { + ax = a * ::log(x) - x - ::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return ::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= ::exp(a - x) * ::pow(x / fac, a); + } else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + + using accscalar_t = opmath_t; + const accscalar_t MACHEP = 5.9604644775390625E-8; + const int MAXITER = 2000; + + int i; + accscalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + using accscalar_t = opmath_t; + int n; + accscalar_t fac = 1; + accscalar_t sum = 0; + accscalar_t term, logx; + const int MAXITER = 2000; + const accscalar_t MACHEP = 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (::fabs(term) <= MACHEP * ::fabs(sum)) { + break; + } + } + + logx = ::log(x); + term = -::expm1(a * logx - ::lgamma(1 + a)); + return term - ::exp(a * logx - ::lgamma(a)) * sum; +} + +template +scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + + using accscalar_t = opmath_t; + const accscalar_t d[25][25] = { + {-3.3333333333333333e-1, 8.3333333333333333e-2, + -1.4814814814814815e-2, 1.1574074074074074e-3, + 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, + -1.85406221071516e-6, 8.296711340953086e-7, + -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, + 9.1476995822367902e-10, -2.551419399494625e-11, + -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, + 3.3717632624009854e-13, -1.3923887224181621e-13, + 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, + -2.9907248030319018e-4, -1.4638452578843418e-6, + 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, + -1.6954149536558306e-6, 8.9075075322053097e-7, + -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, + 3.4463580499464897e-9, -2.3024517174528067e-13, + -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, + 4.6792750266579195e-12, -2.1492464706134829e-12, + 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, + -3.3493161081142236e-4, 2.812695154763237e-4, + -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, + 5.7876949497350524e-6, 4.9387589339362704e-10, + -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, + 2.695423606288966e-8, -1.4578352908731271e-8, + 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, + -6.9957960920705679e-11, 2.5899863874868481e-17, + 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, + 6.7823088376673284e-4, -6.4014752602627585e-4, + 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, + -2.1073920183404862e-5, -8.8585890141255994e-10, + 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, + -1.5344695190702061e-7, 8.862466778790695e-8, + -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, + 5.7370135528051385e-10, -1.887749850169741e-19, + -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + const accscalar_t MACHEP = 5.9604644775390625E-8; + accscalar_t lambda = x / a; + accscalar_t sigma = (x - a) / a; + accscalar_t eta, res, ck, ckterm, term, absterm; + accscalar_t absoldterm = INFINITY; + accscalar_t etapow[25] = {1}; + accscalar_t sum = 0; + accscalar_t afac = 1; + + if (igam) { + sgn = -1; + } else { + sgn = 1; + } + + if (lambda > 1) { + eta = ::sqrt(-2 * (::log1p(sigma) - sigma)); + } else if (lambda < 1) { + eta = -::sqrt(-2 * (::log1p(sigma) - sigma)); + } else { + eta = 0; + } + res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n - 1]; + maxpow += 1; + } + ckterm = d[k][n] * etapow[n]; + ck += ckterm; + if (::fabs(ckterm) < MACHEP * ::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = ::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * ::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a); + + return res; +} + +template +scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + + using accscalar_t = opmath_t; + int i; + accscalar_t ans, ax, c, yc, r, t, y, z; + accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + const int MAXITER = 2000; + const accscalar_t MACHEP = 5.9604644775390625E-8; + const accscalar_t BIG = 16777216.; + const accscalar_t BIGINV = 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = ::fabs((ans - r) / r); + ans = r; + } else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the subtraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + + using accscalar_t = opmath_t; + accscalar_t absxma_a; + + const accscalar_t SMALL = 20.0; + const accscalar_t LARGE = 200.0; + const accscalar_t SMALLRATIO = 0.3; + const accscalar_t LARGERATIO = 4.5; + + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return NAN; + } else if (a == 0) { + if (x > 0) { + return 0.0; + } else { + return NAN; + } + } else if (x == 0) { + return 1.0; + } else if (isinf(a)) { + if (isinf(x)) { + return NAN; + } + return 1.0; + } else if (isinf(x)) { + return 0.0; + } + + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_continued_fraction(a, x); + } + } else if (x <= 0.5) { + if (-0.4 / ::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } else { + return _igamc_helper_series(a, x); + } + } +} + +template +scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the subtraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + + using accscalar_t = opmath_t; + accscalar_t absxma_a; + const accscalar_t SMALL = 20.0; + const accscalar_t LARGE = 200.0; + const accscalar_t SMALLRATIO = 0.3; + const accscalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return NAN; + } else if (a == 0) { + if (x > 0) { + return 1.0; + } else { + return NAN; + } + } else if (x == 0) { + return 0.0; // zero integration limit + } else if (isinf(a)) { + if (isinf(x)) { + return NAN; + } + return 0.0; + } else if (isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. */ + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +} // namespace + +// end of regularized lower & upper incomplete gamma + +namespace c10 { +namespace metal { + +template +inline T igamma(T a, T b) { + return calc_igamma(a, b); +} + +template +inline T igammac(T a, T b) { + return calc_igammac(a, b); +} + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/indexing.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/indexing.h new file mode 100644 index 0000000000000000000000000000000000000000..3a35aa1b87a2aa9a80cfaafd3d0cf0cf3076a215 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/indexing.h @@ -0,0 +1,1050 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Metal indexing primitives +#pragma once +#include +#include +#include + +namespace c10 { +namespace metal { + +// Given coordinates and strides, calculates offset from the start of the +// tensors +template +inline T offset_from_coord( + thread T idx[max_ndim], + constant long* strides, + uint ndim) { + T rc = 0; + for (uint i = 0; i < ndim; ++i) { + rc += idx[i] * T(strides[i]); + } + return rc; +} + +// Given thread index calculates position in the ndim tensor +template +inline void pos_from_thread_index( + T idx, + thread T pos[max_ndim], + constant long* sizes, + uint ndim) { + for (uint i = 0; i < ndim; ++i) { + pos[i] = idx % T(sizes[i]); + idx /= T(sizes[i]); + } +} + +inline long offset_from_thread_index( + long idx, + constant long* sizes, + constant long* strides, + uint ndim) { + long pos[max_ndim]; + pos_from_thread_index(idx, pos, sizes, ndim); + return offset_from_coord(pos, strides, ndim); +} + +template +kernel void unary_dense( + device result_of* output [[buffer(0)]], + constant T* input [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + F f; + output[index] = f(input[index]); +} + +template +kernel void unary_strided( + device result_of* output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant long* sizes [[buffer(2)]], + constant long* input_strides [[buffer(3)]], + constant long* output_strides [[buffer(4)]], + constant uint& ndim [[buffer(5)]], + uint index [[thread_position_in_grid]]) { + F f; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + output[output_offs] = f(input[input_offs]); +} + +#define REGISTER_UNARY_OP(NAME, DTYPE0, DTYPE1) \ + static_assert( \ + ::metal:: \ + is_same_v>, \ + "Output dtype mismatch for unary op " #NAME " and input " #DTYPE0); \ + template [[host_name(#NAME "_dense_" #DTYPE1 "_" #DTYPE0)]] kernel void :: \ + c10::metal::unary_dense( \ + device ::c10::metal::result_of * output, \ + constant DTYPE0 * input, \ + uint index); \ + template [[host_name(#NAME "_strided_" #DTYPE1 "_" #DTYPE0)]] kernel void :: \ + c10::metal::unary_strided( \ + device ::c10::metal::result_of * output, \ + constant DTYPE0 * input, \ + constant long* sizes, \ + constant long* input_strides, \ + constant long* output_strides, \ + constant uint& ndim, \ + uint index) + +#define DEFINE_UNARY_FLOATING_FUNCTOR(NAME) \ + struct NAME##_functor { \ + template \ + inline ::metal::enable_if_t<::metal::is_floating_point_v, T> operator()( \ + const T x) { \ + return T(NAME(x)); \ + } \ + template \ + inline ::metal::enable_if_t<::metal::is_integral_v, float> operator()( \ + const T x) { \ + return NAME(static_cast(x)); \ + } \ + } + +template +kernel void unary_alpha_dense( + device result_of* output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T2& alpha [[buffer(2)]], + uint index [[thread_position_in_grid]]) { + F f; + output[index] = f(input[index], alpha); +} + +template +kernel void unary_alpha_strided( + device result_of* output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant long* sizes [[buffer(2)]], + constant long* input_strides [[buffer(3)]], + constant long* output_strides [[buffer(4)]], + constant uint& ndim [[buffer(5)]], + constant T2& alpha [[buffer(6)]], + uint index [[thread_position_in_grid]]) { + F f; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + output[output_offs] = f(input[input_offs], alpha); +} + +#define REGISTER_UNARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \ + static_assert( \ + ::metal::is_same_v< \ + DTYPEO, \ + ::c10::metal::result_of>, \ + "Output dtype mismatch for unary op " #NAME " and input " #DTYPEI); \ + template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + unary_alpha_dense( \ + device ::c10::metal::result_of * \ + output, \ + constant DTYPEI * input, \ + constant DTYPEA & alpha, \ + uint index); \ + template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + unary_alpha_strided( \ + device ::c10::metal::result_of * \ + output, \ + constant DTYPEI * input, \ + constant long* sizes, \ + constant long* input_strides, \ + constant long* output_strides, \ + constant uint& ndim, \ + constant DTYPEA& alpha, \ + uint index) + +template +inline T val_at_offs(constant void* ptr, long offs) { + return *reinterpret_cast( + static_cast(ptr) + offs); +} + +// Value at offset with dynamic cast from provided type +template +inline T val_at_offs(device void* ptr, long offs) { + return *reinterpret_cast(static_cast(ptr) + offs); +} + +template +inline T val_at_offs(P ptr, long offs, ScalarType type) { + switch (type) { + case ScalarType::Bool: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::Byte: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::Char: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::Short: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::Int: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::Long: + return cast_to(val_at_offs(ptr, offs)); + // Floats + case ScalarType::Float: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::Half: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::BFloat16: + return cast_to(val_at_offs(ptr, offs)); + // Complex + case ScalarType::ComplexHalf: + return cast_to(val_at_offs(ptr, offs)); + case ScalarType::ComplexFloat: + return cast_to(val_at_offs(ptr, offs)); + } +} + +template +inline device T& ref_at_offs(device void* ptr, long offs) { + return *reinterpret_cast(static_cast(ptr) + offs); +} + +// Binary elementwise ops kernels +// Right now there are 4 flavors available: +// - binary_dense where both input, other and output are dense and share the +// same type +// - binary_strided when all inputs are of the same types, but some elements are +// strided +// - binary_dense_cast - inputs are dense, but of different dtypes +// - binary_strided_cast - inputs or output are strided and of different dtypes +// - binary_dense_broadcast - one input is dense, another one is broadcastable +// Note about accuracy (for more info see +// https://github.com/pytorch/pytorch/issues/152736) Sometimes when kernel is +// invoked to produce `half` output, but one of the arguments is float arguments +// should be upcast to float, rather than downcast to half At the moment this is +// expressed with `om_t` optional argument (which stands for opmath_type) which +// is identical to output type but could be something else + +template +kernel void binary_strided( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant long* sizes [[buffer(3)]], + constant long* output_strides [[buffer(4)]], + constant long* input_strides [[buffer(5)]], + constant long* other_strides [[buffer(6)]], + constant uint3& ndim [[buffer(7)]], + uint index [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim.x); + const auto input_offs = offset_from_coord(pos, input_strides, ndim.x); + const auto other_offs = offset_from_coord(pos, other_strides, ndim.x); + const auto output_offs = offset_from_coord(pos, output_strides, ndim.x); + const auto a = val_at_offs(input, input_offs); + const auto b = val_at_offs(other, other_offs); + ref_at_offs(output, output_offs) = + static_cast(f(om_t(a), om_t(b))); +} + +template +kernel void binary_alpha_strided( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other_strides [[buffer(7)]], + constant uint3& ndim [[buffer(8)]], + uint index [[thread_position_in_grid]]) { + F f; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim.x); + const auto input_offs = offset_from_coord(pos, input_strides, ndim.x); + const auto other_offs = offset_from_coord(pos, other_strides, ndim.x); + const auto output_offs = offset_from_coord(pos, output_strides, ndim.x); + const auto a = val_at_offs(input, input_offs); + const auto b = val_at_offs(other, other_offs); + ref_at_offs>(output, output_offs) = f(a, b, alpha); +} + +template > +kernel void binary_strided_cast( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant long* sizes [[buffer(3)]], + constant long* output_strides [[buffer(4)]], + constant long* input_strides [[buffer(5)]], + constant long* other_strides [[buffer(6)]], + constant uint4& ndim_types [[buffer(7)]], + uint index [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim_types.x); + const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x); + const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x); + const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x); + const auto a = val_at_offs( + input, input_offs, static_cast(ndim_types.y)); + const auto b = val_at_offs( + other, other_offs, static_cast(ndim_types.z)); + ref_at_offs(output, output_offs) = static_cast(f(a, b)); +} + +template +kernel void binary_alpha_strided_cast( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other_strides [[buffer(7)]], + constant uint4& ndim_types [[buffer(8)]], + uint index [[thread_position_in_grid]]) { + F f; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim_types.x); + const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x); + const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x); + const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x); + const auto a = + val_at_offs(input, input_offs, static_cast(ndim_types.y)); + const auto b = + val_at_offs(other, other_offs, static_cast(ndim_types.z)); + ref_at_offs>(output, output_offs) = f(a, b, alpha); +} + +template > +kernel void binary_dense( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* other [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + out[tid] = static_cast(f(om_t(input[tid]), om_t(other[tid]))); +} + +template +kernel void binary_alpha_dense( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* other [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + out[tid] = f(input[tid], other[tid], alpha); +} + +template +kernel void binary_dense_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant uint4& sizes_types [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + const auto a = val_at_offs( + input, tid * sizes_types.x, static_cast(sizes_types.z)); + const auto b = val_at_offs( + other, tid * sizes_types.y, static_cast(sizes_types.w)); + out[tid] = static_cast(f(a, b)); +} + +template +kernel void binary_alpha_dense_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + constant uint4& sizes_types [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + const auto a = val_at_offs( + input, tid * sizes_types.x, static_cast(sizes_types.z)); + const auto b = val_at_offs( + other, tid * sizes_types.y, static_cast(sizes_types.w)); + out[tid] = f(a, b, alpha); +} + +template > +kernel void binary_dense_broadcast( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* broadcast [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + out[tid] = static_cast( + f(om_t(input[tid]), om_t(broadcast[tid % broadcast_numel]))); +} + +template > +kernel void binary_dense_broadcast_rhs( + device result_of* out [[buffer(0)]], + constant T* broadcast [[buffer(1)]], + constant T* input [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + out[tid] = static_cast( + f(om_t(broadcast[tid % broadcast_numel]), om_t(input[tid]))); +} + +template +kernel void binary_alpha_dense_broadcast( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* broadcast [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + constant T2& alpha [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + out[tid] = f(input[tid], broadcast[tid % broadcast_numel], alpha); +} + +template +kernel void binary_alpha_dense_broadcast_rhs( + device result_of* out [[buffer(0)]], + constant T* broadcast [[buffer(1)]], + constant T* input [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + constant T2& alpha [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + out[tid] = f(broadcast[tid % broadcast_numel], input[tid], alpha); +} + +template +kernel void binary_dense_broadcast_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* broadcast [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + constant uint4& sizes_types [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + const auto a = val_at_offs( + input, tid * sizes_types.x, static_cast(sizes_types.z)); + const auto b = val_at_offs( + broadcast, + (tid % broadcast_numel) * sizes_types.y, + static_cast(sizes_types.w)); + out[tid] = static_cast(f(a, b)); +} + +template +kernel void binary_dense_broadcast_rhs_cast( + device result_of* out [[buffer(0)]], + constant void* broadcast [[buffer(1)]], + constant void* input [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + constant uint4& sizes_types [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + const auto a = val_at_offs( + broadcast, + (tid % broadcast_numel) * sizes_types.x, + static_cast(sizes_types.z)); + const auto b = val_at_offs( + input, tid * sizes_types.y, static_cast(sizes_types.w)); + out[tid] = static_cast(f(a, b)); +} + +template +kernel void binary_alpha_dense_broadcast_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* broadcast [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + constant T2& alpha [[buffer(4)]], + constant uint4& sizes_types [[buffer(5)]], + uint tid [[thread_position_in_grid]]) { + F f; + const auto a = val_at_offs( + input, tid * sizes_types.x, static_cast(sizes_types.z)); + const auto b = val_at_offs( + broadcast, + (tid % broadcast_numel) * sizes_types.y, + static_cast(sizes_types.w)); + out[tid] = f(a, b, alpha); +} + +template +kernel void binary_alpha_dense_broadcast_rhs_cast( + device result_of* out [[buffer(0)]], + constant void* broadcast [[buffer(1)]], + constant void* input [[buffer(2)]], + constant long& broadcast_numel [[buffer(3)]], + constant T2& alpha [[buffer(4)]], + constant uint4& sizes_types [[buffer(5)]], + uint tid [[thread_position_in_grid]]) { + F f; + const auto a = val_at_offs( + broadcast, + (tid % broadcast_numel) * sizes_types.x, + static_cast(sizes_types.z)); + const auto b = val_at_offs( + input, tid * sizes_types.y, static_cast(sizes_types.w)); + out[tid] = f(a, b, alpha); +} + +template > +kernel void binary_dense_scalar( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + device T* scalar [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + out[tid] = static_cast(f(om_t(input[tid]), om_t(scalar[0]))); +} + +template > +kernel void binary_dense_scalar_lhs( + device result_of* out [[buffer(0)]], + device T* scalar [[buffer(1)]], + constant T* input [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + out[tid] = static_cast(f(om_t(scalar[0]), om_t(input[tid]))); +} + +template +kernel void binary_dense_scalar_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + device void* scalar [[buffer(2)]], + constant uint4& sizes_types [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + const auto a = val_at_offs( + input, tid * sizes_types.x, static_cast(sizes_types.z)); + const auto b = + val_at_offs(scalar, 0, static_cast(sizes_types.w)); + out[tid] = static_cast(f(a, b)); +} + +template +kernel void binary_dense_scalar_lhs_cast( + device result_of* out [[buffer(0)]], + device void* scalar [[buffer(1)]], + constant void* input [[buffer(2)]], + constant uint4& sizes_types [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + const auto a = + val_at_offs(scalar, 0, static_cast(sizes_types.z)); + const auto b = val_at_offs( + input, tid * sizes_types.y, static_cast(sizes_types.w)); + out[tid] = static_cast(f(a, b)); +} + +template +kernel void binary_alpha_dense_scalar( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + device T* scalar [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + out[tid] = f(input[tid], scalar[0], alpha); +} + +template +kernel void binary_alpha_dense_scalar_lhs( + device result_of* out [[buffer(0)]], + device T* scalar [[buffer(1)]], + constant T* input [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + out[tid] = f(scalar[0], input[tid], alpha); +} + +template +kernel void binary_alpha_dense_scalar_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + device void* scalar [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + constant uint4& sizes_types [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + const auto a = val_at_offs( + input, tid * sizes_types.x, static_cast(sizes_types.z)); + const auto b = + val_at_offs(scalar, 0, static_cast(sizes_types.w)); + out[tid] = f(a, b, alpha); +} + +template +kernel void binary_alpha_dense_scalar_lhs_cast( + device result_of* out [[buffer(0)]], + device void* scalar [[buffer(1)]], + constant void* input [[buffer(2)]], + constant T2& alpha [[buffer(3)]], + constant uint4& sizes_types [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + const auto a = + val_at_offs(scalar, 0, static_cast(sizes_types.z)); + const auto b = val_at_offs( + input, tid * sizes_types.y, static_cast(sizes_types.w)); + out[tid] = f(a, b, alpha); +} + +#define REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \ + static_assert( \ + ::metal::is_same_v< \ + DTYPEO, \ + ::c10::metal::result_of>, \ + "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI); \ + template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ + c10::metal::binary_strided( \ + device void* out, \ + constant void* input, \ + constant void* other, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other_strides, \ + constant uint3& ndim, \ + uint tid); \ + template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \ + metal::binary_strided_cast( \ + device void* out, \ + constant void* input, \ + constant void* other, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other_strides, \ + constant uint4& ndim_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ + c10::metal::binary_dense( \ + device ::c10::metal::result_of * \ + out_, \ + constant DTYPEI * input_, \ + constant DTYPEI * other_, \ + uint tid); \ + template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \ + metal::binary_dense_cast( \ + device ::c10::metal::result_of * \ + out_, \ + constant void* input, \ + constant void* other, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_" #DTYPEO "_" #DTYPEI)]] \ + kernel void ::c10::metal:: \ + binary_dense_broadcast( \ + device ::c10::metal::result_of * \ + out_, \ + constant DTYPEI * input_, \ + constant DTYPEI * broadcast_, \ + constant long& broadcast_numel, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_rhs_" #DTYPEO "_" #DTYPEI)]] \ + kernel void ::c10::metal:: \ + binary_dense_broadcast_rhs( \ + device ::c10::metal::result_of * \ + out_, \ + constant DTYPEI * broadcast_, \ + constant DTYPEI * input_, \ + constant long& broadcast_numel, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_cast_" #DTYPEI)]] \ + kernel void ::c10::metal:: \ + binary_dense_broadcast_cast( \ + device ::c10::metal::result_of * \ + out_, \ + constant void* input_, \ + constant void* broadcast_, \ + constant long& broadcast_numel, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_rhs_cast_" #DTYPEI)]] \ + kernel void ::c10::metal:: \ + binary_dense_broadcast_rhs_cast( \ + device ::c10::metal::result_of * \ + out_, \ + constant void* broadcast_, \ + constant void* input_, \ + constant long& broadcast_numel, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_" #DTYPEO "_" #DTYPEI)]] \ + kernel void ::c10::metal::binary_dense_scalar( \ + device ::c10::metal::result_of * out_, \ + constant DTYPEI * input_, \ + device DTYPEI * scalar_, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_lhs_" #DTYPEO "_" #DTYPEI)]] \ + kernel void ::c10::metal:: \ + binary_dense_scalar_lhs( \ + device ::c10::metal::result_of * \ + out_, \ + device DTYPEI * scalar_, \ + constant DTYPEI * input_, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_cast_" #DTYPEI)]] \ + kernel void ::c10::metal:: \ + binary_dense_scalar_cast( \ + device ::c10::metal::result_of * \ + out_, \ + constant void* input_, \ + device void* scalar_, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_lhs_cast_" #DTYPEI)]] \ + kernel void ::c10::metal:: \ + binary_dense_scalar_lhs_cast( \ + device ::c10::metal::result_of * \ + out_, \ + device void* scalar_, \ + constant void* input_, \ + constant uint4& sizes_types, \ + uint tid) + +// OpMath Binary Op promotes inputs to higher precision type before Functor call +#define REGISTER_OPMATH_BINARY_OP(NAME, DTYPEI, DTYPEO) \ + REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, ::c10::metal::opmath_t) + +#define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO) \ + REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI) + +#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \ + static_assert( \ + ::metal::is_same_v< \ + DTYPEO, \ + ::c10::metal::result_of>, \ + "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI); \ + template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_strided( \ + device void* out, \ + constant void* input, \ + constant void* other, \ + constant DTYPEA& alpha, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other_strides, \ + constant uint3& ndim, \ + uint tid); \ + template [[host_name(#NAME "_strided_cast_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_strided_cast( \ + device void* out, \ + constant void* input, \ + constant void* other, \ + constant DTYPEA& alpha, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other_strides, \ + constant uint4& ndim_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant DTYPEI * input_, \ + constant DTYPEI * other_, \ + constant DTYPEA & alpha, \ + uint tid); \ + template \ + [[host_name(#NAME "_dense_cast_" #DTYPEI "_" #DTYPEA)]] kernel void :: \ + c10::metal::binary_alpha_dense_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant void* input, \ + constant void* other, \ + constant DTYPEA& alpha, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_broadcast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant DTYPEI * input_, \ + constant DTYPEI * broadcast_, \ + constant long& broadcast_numel, \ + constant DTYPEA& alpha, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_rhs_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_broadcast_rhs( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant DTYPEI * broadcast_, \ + constant DTYPEI * input_, \ + constant long& broadcast_numel, \ + constant DTYPEA& alpha, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_cast_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_broadcast_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant void* input_, \ + constant void* broadcast_, \ + constant long& broadcast_numel, \ + constant DTYPEA& alpha, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_broadcast_rhs_cast_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_broadcast_rhs_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant void* broadcast_, \ + constant void* input_, \ + constant long& broadcast_numel, \ + constant DTYPEA& alpha, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_scalar( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant DTYPEI * input_, \ + device DTYPEI * scalar_, \ + constant DTYPEA & alpha, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_lhs_" #DTYPEO "_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_scalar_lhs( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + device DTYPEI * scalar_, \ + constant DTYPEI * input_, \ + constant DTYPEA & alpha, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_cast_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_scalar_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant void* input_, \ + device void* scalar_, \ + constant DTYPEA& alpha, \ + constant uint4& sizes_types, \ + uint tid); \ + template [[host_name(#NAME "_dense_scalar_lhs_cast_" #DTYPEI \ + "_" #DTYPEA)]] kernel void ::c10::metal:: \ + binary_alpha_dense_scalar_lhs_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + device void* scalar_, \ + constant void* input_, \ + constant DTYPEA& alpha, \ + constant uint4& sizes_types, \ + uint tid) + +// Ternary elementwise ops kernels +// Right now there are 4 flavors available: +// - ternary_dense where both input, other1, other2, and output are dense and +// share the same type +// - ternary_strided when all inputs are of the same types, but some elements +// are strided +// - ternary_dense_cast - inputs are dense, but of different dtypes +// - ternary_strided_cast - inputs or output are strided and of different dtypes +// Note about accuracy (for more info see +// https://github.com/pytorch/pytorch/issues/152736) Sometimes when kernel is +// invoked to produce `half` output, but one of the arguments is float arguments +// should be upcast to float, rather than downcast to half At the moment this is +// expressed with `om_t` optional argument (which stands for opmath_type) which +// is identical to output type but could be something else + +template +kernel void ternary_strided( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other1 [[buffer(2)]], + constant void* other2 [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other1_strides [[buffer(7)]], + constant long* other2_strides [[buffer(8)]], + constant uint& ndim [[buffer(9)]], + constant uint4& types [[buffer(10)]], + uint index [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto other1_offs = offset_from_coord(pos, other1_strides, ndim); + const auto other2_offs = offset_from_coord(pos, other2_strides, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + const auto a = val_at_offs(input, input_offs); + const auto b = val_at_offs(other1, other1_offs); + const auto c = val_at_offs(other2, other2_offs); + ref_at_offs(output, output_offs) = + static_cast(f(om_t(a), om_t(b), om_t(c))); +} + +template > +kernel void ternary_strided_cast( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other1 [[buffer(2)]], + constant void* other2 [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other1_strides [[buffer(7)]], + constant long* other2_strides [[buffer(8)]], + constant uint& ndim [[buffer(9)]], + constant uint4& types [[buffer(10)]], + uint index [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto other1_offs = offset_from_coord(pos, other1_strides, ndim); + const auto other2_offs = offset_from_coord(pos, other2_strides, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + const auto a = + val_at_offs(input, input_offs, static_cast(types.x)); + const auto b = + val_at_offs(other1, other1_offs, static_cast(types.y)); + const auto c = + val_at_offs(other2, other2_offs, static_cast(types.z)); + ref_at_offs(output, output_offs) = static_cast(f(a, b, c)); +} + +template > +kernel void ternary_dense( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* other1 [[buffer(2)]], + constant T* other2 [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + out[tid] = static_cast( + f(om_t(input[tid]), om_t(other1[tid]), om_t(other2[tid]))); +} + +template +kernel void ternary_dense_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other1 [[buffer(2)]], + constant void* other2 [[buffer(3)]], + constant uint3& sizes [[buffer(4)]], + constant uint3& types [[buffer(5)]], + uint tid [[thread_position_in_grid]]) { + F f; + using res_t = result_of; + const auto a = + val_at_offs(input, tid * sizes.x, static_cast(types.x)); + const auto b = val_at_offs( + other1, tid * sizes.y, static_cast(types.y)); + const auto c = val_at_offs( + other2, tid * sizes.z, static_cast(types.z)); + out[tid] = static_cast(f(a, b, c)); +} + +#define REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \ + static_assert( \ + ::metal::is_same_v< \ + DTYPEO, \ + ::c10::metal::result_of>, \ + "Output dtype mismatch for ternary op " #NAME " and input " #DTYPEI); \ + template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ + c10::metal::ternary_strided( \ + device void* out, \ + constant void* input, \ + constant void* other1, \ + constant void* other2, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other1_strides, \ + constant long* other2_strides, \ + constant uint& ndim, \ + constant uint4& types, \ + uint tid); \ + template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \ + metal::ternary_strided_cast( \ + device void* out, \ + constant void* input, \ + constant void* other1, \ + constant void* other2, \ + constant long* sizes, \ + constant long* output_strides, \ + constant long* input_strides, \ + constant long* other1_strides, \ + constant long* other2_strides, \ + constant uint& ndim, \ + constant uint4& types, \ + uint tid); \ + template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \ + c10::metal::ternary_dense( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant DTYPEI * input_, \ + constant DTYPEI * other1_, \ + constant DTYPEI * other2_, \ + uint tid); \ + template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \ + metal::ternary_dense_cast( \ + device ::c10::metal:: \ + result_of * \ + out_, \ + constant void* input, \ + constant void* other1, \ + constant void* other2, \ + constant uint3& sizes, \ + constant uint3& types, \ + uint tid) + +// OpMath ternary Op promotes inputs to higher precision type before Functor +// call +#define REGISTER_OPMATH_TERNARY_OP(NAME, DTYPEI, DTYPEO) \ + REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, ::c10::metal::opmath_t) + +#define REGISTER_TERNARY_OP(NAME, DTYPEI, DTYPEO) \ + REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI) + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/random.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/random.h new file mode 100644 index 0000000000000000000000000000000000000000..711e446d667decbbf3e2cfc7fc5a0da5d81d3123 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/random.h @@ -0,0 +1,83 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Philox Counter based RNG implementation for Metal +// Borrowed from aten/src/ATen/core/PhiloxRNGEngine.h +// Which in turn borrowed from +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +#pragma once +#include + +namespace c10 { +namespace metal { + +namespace detail { + +constexpr float uint32_to_uniform_float(uint32_t value) { + // maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + constexpr float scale = 4.6566127342e-10; + return static_cast(value & 0x7FFFFFFF) * scale; +} + +inline uint2 splitlong(ulong v) { + return uint2(v >> 32, v & 0xffffffff); +} + +} // namespace detail + +namespace philox4 { + +uint2 mulhilo(uint a, uint b) { + auto rc = static_cast(a) * b; + return detail::splitlong(rc); +} +uint4 single_round(uint4 ctr, uint2 key) { + constexpr uint kPhiloxSA = 0xD2511F53; + constexpr uint kPhiloxSB = 0xCD9E8D57; + auto rc0 = mulhilo(kPhiloxSA, ctr.x); + auto rc1 = mulhilo(kPhiloxSB, ctr.z); + return uint4(rc1.y ^ ctr.y ^ key.x, rc1.x, rc0.y ^ ctr.w ^ key.y, rc0.x); +} + +uint4 multiple_rounds(uint4 ctr, uint2 key, uint rounds) { + constexpr uint2 kPhilox10 = {0x9E3779B9, 0xBB67AE85}; + for (uint round = 0; round < rounds - 1; ++round) { + ctr = single_round(ctr, key); + key += kPhilox10; + } + return ctr; +} + +uint4 rand(long seed, long index) { + uint4 ctr = 0; + ctr.zw = detail::splitlong(index); + return multiple_rounds(ctr, detail::splitlong(seed), 10); +} + +} // namespace philox4 + +float randn(long seed, long index) { + auto value = philox4::rand(seed, index); + float u1 = 1.0 - detail::uint32_to_uniform_float(value.x); + float u2 = 1.0 - detail::uint32_to_uniform_float(value.y); + return ::metal::sqrt(-2.0 * ::metal::log(u1)) * + ::metal::cos(2.0 * M_PI_F * u2); +} + +float rand(long seed, long index) { + auto value = philox4::rand(seed, index); + return detail::uint32_to_uniform_float(value.x); +} + +long randint64(long seed, long index, long low, long high) { + auto range = high - low; + auto value = philox4::rand(seed, index); + // TODO: Implement better algorithm for large ranges + return low + + static_cast(detail::uint32_to_uniform_float(value.x) * range); +} + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/reduction_utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/reduction_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f23c1af774ed88568bc1abacc668e98760bb6f98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/reduction_utils.h @@ -0,0 +1,364 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { +namespace metal { +namespace detail { +template +struct simd_type { + using t = T; +}; + +// Helper that allows one to run simd ops over bfl16 by upcasting them to fp32 +template +using simd_type_t = typename simd_type::t; + +template <> +struct simd_type { + using t = float; +}; +} // namespace detail + +template +inline ::metal::enable_if_t, T> simd_sum(T val) { + return T(::metal::simd_sum(detail::simd_type_t(val))); +} + +template +inline ::metal::enable_if_t, T> simd_prod(T val) { + return T(::metal::simd_product(detail::simd_type_t(val))); +} + +// Extend simd_broadcast to 64-bit integral types using int2 trick +template < + typename T, + ::metal::enable_if_t<::metal::is_integral_v && sizeof(T) == 8, bool> = + true> +inline T simd_broadcast(T val, ushort lane_id) { + return as_type(::metal::simd_broadcast(as_type(val), lane_id)); +} + +template < + typename T, + ::metal::enable_if_t || sizeof(T) != 8, bool> = + true> +inline T simd_broadcast(T val, ushort lane_id) { + return ::metal::simd_broadcast(val, lane_id); +} + +// Floating simd_min/max with nan propagation +template < + typename T, + ::metal::enable_if_t<::metal::is_floating_point_v, bool> = true> +inline T simd_max(T val) { + if (::metal::simd_any(::metal::isnan(val))) { + return ::metal::numeric_limits::quiet_NaN(); + } + return T(::metal::simd_max(detail::simd_type_t(val))); +} + +template < + typename T, + ::metal::enable_if_t<::metal::is_floating_point_v, bool> = true> +inline T simd_min(T val) { + if (::metal::simd_any(::metal::isnan(val))) { + return ::metal::numeric_limits::quiet_NaN(); + } + return T(::metal::simd_min(detail::simd_type_t(val))); +} + +template < + typename T, + ::metal::enable_if_t<::metal::is_integral_v && sizeof(T) != 8, bool> = + true> +inline T simd_max(T val) { + return ::metal::simd_max(val); +} + +template < + typename T, + ::metal::enable_if_t<::metal::is_integral_v && sizeof(T) != 8, bool> = + true> +inline T simd_min(T val) { + return ::metal::simd_min(val); +} + +// Metal does not support SIMD reductions over 64-bit types, but it could be +// implement using simd_shuffle_down, that yields result in log2(simdgroup_size) +// iterations Use fill variant, as shuffle down returns garbage if inactive +// thread is referenced (on M1/M2, works fine on M4) and broadcast result to all +// threads in the end. Implementation heavily borrows from +// https://github.com/ml-explore/mlx/blob/86389bf9707f46101af45d90510e8e97c8a90b93/mlx/backend/metal/kernels/reduction/ops.h#L16 +template +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_sum(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val += as_type( + ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); + } + return simd_broadcast(val, 0); +} + +template +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_prod(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val *= as_type( + ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); + } + return simd_broadcast(val, 0); +} + +template +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_max(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val = ::metal::max( + val, + as_type(::metal::simd_shuffle_and_fill_down( + as_type(val), int2(0), i))); + } + return simd_broadcast(val, 0); +} + +template +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_min(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val = ::metal::min( + val, + as_type(::metal::simd_shuffle_and_fill_down( + as_type(val), int2(0), i))); + } + return simd_broadcast(val, 0); +} + +// argmin/argmax helpers using simd_ballot +template < + typename T, + ::metal::enable_if_t<::metal::is_integral_v, bool> = true> +inline ::c10::metal::pair simd_argmin(T val) { + const auto rc = simd_min(val); + const auto vote = ::metal::simd_ballot(val == rc); + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; +} + +template < + typename T, + ::metal::enable_if_t<::metal::is_floating_point_v, bool> = true> +inline ::c10::metal::pair simd_argmin(T val) { + const auto rc = simd_min(val); + const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val)); + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; +} + +template < + typename T, + ::metal::enable_if_t<::metal::is_integral_v, bool> = true> +inline ::c10::metal::pair simd_argmax(T val) { + const auto rc = simd_max(val); + const auto vote = ::metal::simd_ballot(val == rc); + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; +} + +template < + typename T, + ::metal::enable_if_t<::metal::is_floating_point_v, bool> = true> +inline ::c10::metal::pair simd_argmax(T val) { + const auto rc = simd_max(val); + const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val)); + return {rc, static_cast(::metal::ctz(static_cast(vote)))}; +} + +template +inline c10::metal::pair simd_argmin(ARG_T val, IDX_T idx_val) { + auto rc = simd_argmin(val); + return {rc.first, simd_broadcast(idx_val, rc.second)}; +} + +template +inline c10::metal::pair simd_argmax(ARG_T val, IDX_T idx_val) { + auto rc = simd_argmax(val); + return {rc.first, simd_broadcast(idx_val, rc.second)}; +} + +// Below algorithms are written with hardcoded assumption that simdgroup is 32 +// and threadgroup_max is 1024, i.e. reduction can be done in two stages max +template +opmath_t threadgroup_sum( + threadgroup opmath_t* data, + T val, + unsigned idx, + unsigned size) { + auto rc = simd_sum(static_cast>(val)); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; + } + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_sum(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return data[0]; +} + +template +opmath_t threadgroup_prod( + threadgroup opmath_t* data, + T val, + unsigned idx, + unsigned size) { + auto rc = simd_prod(static_cast>(val)); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; + } + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_prod(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return data[0]; +} + +template +T threadgroup_max(threadgroup T* data, T val, unsigned idx, unsigned size) { + auto rc = simd_max(val); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; + } + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_max(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return data[0]; +} + +template +T threadgroup_min(threadgroup T* data, T val, unsigned idx, unsigned size) { + auto rc = simd_min(val); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; + } + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_min(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return data[0]; +} + +template +float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + float m = data[0]; + float m2 = 0; + for (unsigned idx = 1; idx < size; ++idx) { + float delta = data[idx] - m; + m += delta / (idx + 1); + m2 += delta * (data[idx] - m); + } + return float3(m, m2, size); +} + +// Each vec3type is tuple of mean, m2 and weight +template +float3 welford_combine(T a, T b) { + float delta = b.x - a.x; + float new_weight = a.z + b.z; + auto w2_over_w = new_weight != 0 ? b.z / new_weight : 0.0; + return float3( + a.x + delta * w2_over_w, + a.y + b.y + delta * delta * a.z * w2_over_w, + new_weight); +} + +template +float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + float3 rc = data[0]; + for (unsigned idx = 1; idx < size; ++idx) { + rc = welford_combine(rc, data[idx]); + } + return rc; +} + +template +IDX_T threadgroup_argmax( + threadgroup ARG_T* arg_data, + threadgroup IDX_T* idx_data, + ARG_T val, + IDX_T idx_val, + unsigned idx, + unsigned size) { + auto rc = simd_argmax(val, idx_val); + if (size <= simdgroup_size) { + return rc.second; + } + if (idx % simdgroup_size == 0) { + arg_data[idx / simdgroup_size] = rc.first; + idx_data[idx / simdgroup_size] = rc.second; + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_argmax(arg_data[idx], idx_data[idx]); + if (idx == 0) { + idx_data[0] = rc1.second; + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return idx_data[0]; +} + +template +IDX_T threadgroup_argmin( + threadgroup ARG_T* arg_data, + threadgroup IDX_T* idx_data, + ARG_T val, + IDX_T idx_val, + unsigned idx, + unsigned size) { + auto rc = simd_argmin(val, idx_val); + if (size <= simdgroup_size) { + return rc.second; + } + if (idx % simdgroup_size == 0) { + arg_data[idx / simdgroup_size] = rc.first; + idx_data[idx / simdgroup_size] = rc.second; + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_argmin(arg_data[idx], idx_data[idx]); + if (idx == 0) { + idx_data[0] = rc1.second; + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return idx_data[0]; +} + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/special_math.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/special_math.h new file mode 100644 index 0000000000000000000000000000000000000000..d0fb82cc0ad813b59aeb2e62a0d93ca37ac0d54b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/special_math.h @@ -0,0 +1,2064 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Implementation of special math functions for Metal +#pragma once +#include +#include +#include +#include + +namespace c10 { +namespace metal { + +/* + * Approximation to the error function. + * Based on code from: + * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 + * Copy-n-pasted from + * https://github.com/ml-explore/mlx/blob/2e8cf0b4506c200a5c2d199ecbbf655fdf4c2ce2/mlx/backend/metal/kernels/erf.h#L11 + */ +template +inline float erf(T x) { + const auto a = static_cast(x); + const auto t = ::metal::abs(a); + const auto s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + auto r = ::metal::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + const auto u = ::metal::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = ::metal::fma(r, s, u); + r = ::metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = ::metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = ::metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = ::metal::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - ::metal::exp(r); + r = ::metal::copysign(r, a); + return r; + } + + // maximum error 0.98929 ulp + auto r = -5.96761703e-4f; // -0x1.38e000p-11 + r = ::metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = ::metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = ::metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = ::metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = ::metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = ::metal::fma(r, a, a); + return r; +} + +template +float erfc(T x) { + return 1.0 - erf(x); +} + +template +inline float erfinv(T y) { + /* coefficients in rational expansion */ + constexpr float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331}; + constexpr float b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801}; + constexpr float c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311}; + constexpr float d[2] = {3.543889200, 1.637067800}; + + float x, z, num, dem; /*working variables */ + + float y_abs = ::metal::abs(static_cast(y)); + if (y_abs >= 1.0f) { + return y_abs > 1.0f ? NAN + : ::metal::copysign(INFINITY, static_cast(y)); + } + if (y_abs <= 0.7f) { + z = y * y; + num = ((a[3] * z + a[2]) * z + a[1]) * z + a[0]; + dem = (((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0f; + x = y * num / dem; + } else { + z = ::metal::sqrt(-1.0f * ::metal::log((1.0 - y_abs) / 2.0)); + num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; + dem = (d[1] * z + d[0]) * z + 1.0f; + x = ::metal::copysign(num, static_cast(y)) / dem; + } + + return x; +} + +/* + * For licensing information and documentation, please refer to the cpu + * implementation located in "ATen/native/Math.h". + */ + +template +inline T chbevl(T x, const float array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); +} + +// Copied from +// https://github.com/pytorch/pytorch/blob/58b661cda2c002a8e1ac3bee494bfe1f7420437c/aten/src/ATen/native/cuda/Math.cuh#L502 + +template +inline T i0(T _x) { + auto x = ::metal::fabs(_x); + + if (x <= 8.0) { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + constexpr float A[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + auto y = (x / 2.0) - 2.0; + return static_cast(::metal::exp(x) * chbevl(y, A, 30)); + } + + // Handles x > 8 case + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). + */ + constexpr float B[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return static_cast( + (::metal::exp(x) * chbevl(32.0 / x - 2.0, B, 25)) / ::metal::sqrt(x)); +} + +template +inline T i0e(T _x) { + auto x = ::metal::fabs(_x); + + if (x <= 8.0) { + constexpr float coefficients[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + auto y = (x / 2.0) - 2.0; + return static_cast(chbevl(y, coefficients, int{30})); + } + + // x > 8 + constexpr float coefficients[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return static_cast( + chbevl(32.0 / x - 2.0, coefficients, 25) / ::metal::sqrt(x)); +} + +// Copied from +// https://github.com/pytorch/pytorch/blob/58b661cda2c002a8e1ac3bee494bfe1f7420437c/aten/src/ATen/native/cuda/Math.cuh#L576 + +template +inline T i1(T _x) { + const auto x = ::metal::fabs(_x); + + if (x <= 8.0) { + // Chebyshev coefficients for exp(-x) i1(x) in the internal [0, 8] + // lim(x->0){ exp(-x) i1(x) / x } = 1/2 + constexpr float coefficients[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + const auto y = x / 2.0 - 2.0; + const auto out = ::metal::exp(x) * x * chbevl(y, coefficients, 29); + return static_cast(_x < T(0.) ? -out : out); + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval [8, infinity] + // lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi) + constexpr float coefficients[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + const auto out = (::metal::exp(x) * chbevl(32. / x - 2., coefficients, 25)) / + ::metal::sqrt(x); + return static_cast(_x < T(0.) ? -out : out); +} + +template +inline T i1e(T _x) { + const auto x = ::metal::fabs(_x); + if (x <= 8.0) { + // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8]. + // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2. + constexpr float coefficients[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + const auto y = x / 2.0 - 2.0; + const auto out = chbevl(y, coefficients, 17) * x; + return static_cast(_x < 0. ? -out : out); + } + + // Chebyshev coefficients for exp(-x) sqrt(x) i1(x) + // in the inverted interval (8, infinity]. + // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi). + // TODO: what's an "inverted interval"? Open on the left + // and closed on the right? + constexpr float coefficients[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + const auto out = + chbevl(32. / x - 2., coefficients, 7) / ::metal::precise::sqrt(x); + return static_cast(_x < 0. ? -out : out); +} + +// gamma, lgamma +template +inline float log_gamma(const T); + +template +inline float gamma(const T x) { + if (x < 0.001) { + constexpr float EULER_MASCHERONI = 0.577215664901532860606512090; + // For small x, 1/gamma(x) has power series x + gamma x^2 - ... + // So in this range, 1/gamma(x) = x + gamma x^2 with error on the order of + // x^3. The relative error over this interval is less than 6e-7. + + return 1.0 / (x * (1.0 + EULER_MASCHERONI * x)); + } + if (x >= 12.0) { + return ::metal::exp(log_gamma(x)); + } + // The algorithm directly approximates gamma over (1,2) and uses + // reduction identities to reduce other arguments to this interval. + // numerator coefficients for gamma approximation over the interval (1,2) + constexpr float GAMMA_NUMERATOR_COEF[8] = { + -1.71618513886549492533811E+0, + 2.47656508055759199108314E+1, + -3.79804256470945635097577E+2, + 6.29331155312818442661052E+2, + 8.66966202790413211295064E+2, + -3.14512729688483675254357E+4, + -3.61444134186911729807069E+4, + 6.64561438202405440627855E+4}; + + // denominator coefficients for gamma approximation over the interval (1,2) + constexpr float GAMMA_DENOMINATOR_COEF[8] = { + -3.08402300119738975254353E+1, + 3.15350626979604161529144E+2, + -1.01515636749021914166146E+3, + -3.10777167157231109440444E+3, + 2.25381184209801510330112E+4, + 4.75584627752788110767815E+3, + -1.34659959864969306392456E+5, + -1.15132259675553483497211E+5}; + + // Add or subtract integers as necessary to bring y into (1,2) + float y = 1.0 + ::metal::fract(x); + + float num = 0.0; + float den = 1.0; + + float z = y - 1; + for (int i = 0; i < 8; i++) { + num = (num + GAMMA_NUMERATOR_COEF[i]) * z; + den = den * z + GAMMA_DENOMINATOR_COEF[i]; + } + float result = num / den + 1.0; + + // Apply correction if argument was not initially in (1,2) + if (x < 1.0) { + // identity gamma(z) = gamma(z+1)/z + result /= (y - 1.0); + } else { + // identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z) + auto n = static_cast(::metal::floor(x)); + for (int i = 1; i < n; i++) { + result *= y++; + } + } + + return result; +} + +template +inline float log_gamma(const T x) { + constexpr float LOG_PI = 1.14472988584940017414342735135305; + constexpr float HALF_LOG_TWO_PI = 0.91893853320467274178032973640562; + constexpr float LGAMMA_EXPANSION_COEF[8] = { + 1.0 / 12.0, + -1.0 / 360.0, + 1.0 / 1260.0, + -1.0 / 1680.0, + 1.0 / 1188.0, + -691.0 / 360360.0, + 1.0 / 156.0, + -3617.0 / 122400.0}; + + float rc; + + const auto abs_x = ::metal::abs(static_cast(x)); + if (abs_x == 0) { + return INFINITY; + } + if (abs_x < 12.0) { + rc = ::metal::log(::metal::abs(gamma(abs_x))); + } else { + // Abramowitz and Stegun 6.1.41 + // Asymptotic series should be good to at least 11 or 12 figures + // For error analysis, see Whittiker and Watson + // A Course in Modern Analysis (1927), page 252 + + float z = 1.0 / (abs_x * abs_x); + float sum = LGAMMA_EXPANSION_COEF[7]; + + for (int i = 6; i >= 0; i--) { + sum *= z; + sum += LGAMMA_EXPANSION_COEF[i]; + } + float series = sum / abs_x; + + rc = (abs_x - 0.5) * ::metal::log(abs_x) - abs_x + HALF_LOG_TWO_PI + series; + } + + if (x >= 0) { + return rc; + } + + // Reflection formula + // Compute arg first to workaround Metal compiler bgg of sorts on M4 + // See https://github.com/pytorch/pytorch/pull/145740 for more details + auto log_arg = abs_x * ::metal::abs(::metal::sinpi(abs_x)); + return LOG_PI - rc - ::metal::log(log_arg); +} + +inline float zeta(float x, float q) { + constexpr float MACHEP = 1.11022302462515654042E-16; + constexpr float ZETA_EXPANSION[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, + 7.47242496e10, + -2.950130727918164224e12, + 1.1646782814350067249e14, + -4.5979787224074726105e15, + 1.8152105401943546773e17, + -7.1661652561756670113e18}; + if (x == 1.0f) { + return INFINITY; + } + + if (x < 1.0f) { + return NAN; + } + + if (q <= 0.0f) { + if (q == ::metal::trunc(q)) { + return INFINITY; + } + if (x != ::metal::trunc(x)) { + return NAN; + } + } + + float s = ::metal::pow(q, -x); + float a = q; + int i = 0; + float b = 0.0f; + while ((i < 9) || (a <= 9.0f)) { + i += 1; + a += 1.0f; + b = ::metal::pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return s; + } + } + + float w = a; + s += b * w / (x - 1.0f); + s -= 0.5f * b; + a = 1.0f; + float t; + float k = 0.0f; + for (int i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / ZETA_EXPANSION[i]; + s += t; + t = ::metal::fabs(t / s); + if (t < MACHEP) { + return s; + } + k += 1.0f; + a *= x + k; + b /= w; + k += 1.0f; + } + return s; +} + +inline float calc_digamma_positive_domain(float x) { + constexpr float DIGAMMA_COEF[7] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + // Push x to be >= 10 + float result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + constexpr float PSI_10 = 2.25175258906672110764; + return result + PSI_10; + } + + // Compute asymptotic digamma + float y = 0; + if (x < 1.0E+17) { + float z = 1.0 / (x * x); + for (int i = 0; i <= 6; i++) { + y += ::metal::pow(z, i) * DIGAMMA_COEF[i]; + } + y *= z; + } + return result + ::metal::log(x) - (0.5 / x) - y; +} + +template +inline float digamma(T0 x) { + if (x < 0.0f) { + if (x == ::metal::trunc(x)) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return NAN; + } else { + // Extracts the fractional part of x as r, since tan(pi * r) is more + // numerically accurate than tan(pi * x). While these operations are + // mathematically equivalent since both x and r are in radians and tan() + // has a periodicity of pi, in practice the computation of pi * x is a + // source of error (when |x| > 1). + float r = ::metal::fract(x); + return calc_digamma_positive_domain(1.0f - x) - + M_PI_F / ::metal::tan(M_PI_F * r); + } + } else if (x == 0.0f) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return ::metal::copysign(INFINITY, static_cast(-x)); + } else { + return calc_digamma_positive_domain(x); + } +} + +template +inline float polygamma(const int64_t order, const T0 input) { + // Filter out n == 0. + if (order == 0) { + return digamma(input); + } + + float x = input; + float n = order; + float sgn = ((order % 2) ? 1 : -1); + return sgn * gamma(n + 1) * zeta(n + 1, x); +} + +template +inline ::metal::enable_if_t, T> sinc(T a) { + if (a == static_cast(0)) { + return static_cast(1); + } + auto product = M_PI_F * static_cast(a); + return static_cast(::metal::precise::sin(product) / product); +} + +// Complex sinc2 implementation +template +inline ::metal::enable_if_t, T> sinc(T inp) { + auto a = static_cast(inp) * M_PI_F; + const float a2 = a.x * a.x + a.y * a.y; + if (a2 == 0) { + return 0; + } + float cosx; + float sinx = ::metal::sincos(a.x, cosx); + float sinhy = ::metal::sinh(a.y); + float coshy = ::metal::cosh(a.y); + auto re = sinx * coshy * a.x + cosx * sinhy * a.y; + auto im = cosx * sinhy * a.x - sinx * coshy * a.y; + return T(re, im) / a2; +} + +template +inline T spherical_bessel_j0(T x) { + if (::metal::isinf(x)) + return T(0.0); + T x2 = x * x; + T k1 = static_cast(-1.0); + T k2 = static_cast(1.0); + + if (::metal::fabs(static_cast(x)) < T(0.5)) { + return T(1.0) + + x2 * + (k1 / T(6.0) + + x2 * + (k2 / T(120.0) + + x2 * + (k1 / T(5040.0) + + x2 * + (k2 / T(362880.0) + + x2 * + (k1 / T(39916800.0) + + x2 * (k2 / T(6227020800.0))))))); + } + + return static_cast(::metal::sin(x) / x); +} + +template +inline ::metal::enable_if_t, T> logaddexp( + T a, + T b) { + float a0 = static_cast(a); + float b0 = static_cast(b); + if (::metal::isinf(a0) && a0 == b0) { + return static_cast(a0); + } else { + float m0 = ::metal::max(a0, b0); + return static_cast( + m0 + ::c10::metal::log1p(::metal::exp(-::metal::abs(a0 - b0)))); + } +} + +// The function is ported from mlx +template +inline ::metal::enable_if_t, T> logaddexp(T a, T b) { + if (::metal::isnan(a.x) || ::metal::isnan(a.y) || ::metal::isnan(b.x) || + ::metal::isnan(b.y)) { + return T(NAN, NAN); + } + + T maxval = a.x > b.x ? a : b; + T minval = a.x < b.x ? a : b; + constexpr auto inf = ::metal::numeric_limits::infinity().x; + + if (minval.x == -inf || maxval.x == inf) { + return maxval; + } + + float2 maxval_ = static_cast(maxval); + float2 minval_ = static_cast(minval); + float m = ::metal::exp(minval_.x - maxval_.x); + float2 dexp{ + m * ::metal::cos(minval_.y - maxval_.y), + m * ::metal::sin(minval_.y - maxval_.y), + }; + return static_cast(maxval_ + ::c10::metal::log1p(dexp)); +} + +template +inline T logaddexp2(T a, T b) { + constexpr auto log_2 = float(0.693147180559945309417232121458176); + constexpr auto inv_log_2 = float(1) / log_2; + float a0 = static_cast(a); + float b0 = static_cast(b); + if (::metal::isinf(a0) && a0 == b0) { + return static_cast(a0); + } else { + float m0 = ::metal::max(a0, b0); + return static_cast( + m0 + + ::c10::metal::log1p(::metal::pow(float(2), -::metal::abs(a0 - b0))) * + inv_log_2); + } +} + +template +inline float xlog1py(T x, T y) { + if (::metal::isnan(y)) { + return NAN; + } + + if (x == 0) { + return x; + } + + return x * ::c10::metal::log1p(y); +} + +template +inline T entr(T a) { + if (a != a) { + return a; + } + + if (a > 0) { + return static_cast(-a * ::metal::log(a)); + } + + if (a == 0) { + return 0; + } + + return static_cast(-INFINITY); +} + +// Copy-n-paste from aten/src/ATen/native/cuda/Math.cuh lines 1463-1915 +template +inline float bessel_j0_forward(T x) { + constexpr float PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + constexpr float PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + constexpr float QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + constexpr float QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + constexpr float RP[] = { + -4.79443220978201773821e+09, + +1.95617491946556577543e+12, + -2.49248344360967716204e+14, + +9.70862251047306323952e+15, + }; + + constexpr float RQ[] = { + +4.99563147152651017219e+02, + +1.73785401676374683123e+05, + +4.84409658339962045305e+07, + +1.11855537045356834862e+10, + +2.11277520115489217587e+12, + +3.10518229857422583814e+14, + +3.18121955943204943306e+16, + +1.71086294081043136091e+18, + }; + + if (x < T(0)) { + x = -x; + } + + if (x <= T(5.0)) { + if (x < T(0.00001)) { + return 1.0 - x * x / 4.0; + } + + float rp = 0.0; + + for (auto index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + float rq = 0.0; + + for (auto index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return (x * x - 5.78318596294678452118e+00) * + (x * x - T(3.04712623436620863991e+01)) * rp / rq; + } + + float pp = 0.0; + + for (auto index = 0; index <= 6; index++) { + pp = pp * (25.0 / (x * x)) + PP[index]; + } + + float pq = 0.0; + + for (auto index = 0; index <= 6; index++) { + pq = pq * (25.0 / (x * x)) + PQ[index]; + } + + float qp = 0.0; + + for (auto index = 0; index <= 7; index++) { + qp = qp * (25.0 / (x * x)) + QP[index]; + } + + float qq = 0.0; + + for (auto index = 0; index <= 6; index++) { + qq = qq * (25.0 / (x * x)) + QQ[index]; + } + + return (pp / pq * + ::metal::precise::cos( + x - T(0.785398163397448309615660845819875721)) - + 5.0 / x * (qp / qq) * + ::metal::precise::sin( + x - 0.785398163397448309615660845819875721)) * + 0.797884560802865355879892119868763737 / ::metal::precise::sqrt(x); +} // bessel_j0_forward(T x) + +template +inline float bessel_y0_forward(T x) { + constexpr float PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + constexpr float PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + constexpr float QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + constexpr float QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + constexpr float YP[] = { + +1.55924367855235737965e+04, + -1.46639295903971606143e+07, + +5.43526477051876500413e+09, + -9.82136065717911466409e+11, + +8.75906394395366999549e+13, + -3.46628303384729719441e+15, + +4.42733268572569800351e+16, + -1.84950800436986690637e+16, + }; + + constexpr float YQ[] = { + +1.04128353664259848412e+03, + +6.26107330137134956842e+05, + +2.68919633393814121987e+08, + +8.64002487103935000337e+10, + +2.02979612750105546709e+13, + +3.17157752842975028269e+15, + +2.50596256172653059228e+17, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -INFINITY; + } + + if (x < T(0.0)) { + return NAN; + } + + float yp = 0.0; + + for (auto index = 0; index <= 7; index++) { + yp = yp * (x * x) + YP[index]; + } + + float yq = 0.0; + + for (auto index = 0; index <= 6; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return yp / yq + + (0.636619772367581343075535053490057448 * ::metal::precise::log(x) * + bessel_j0_forward(x)); + } + + float pp = 0.0; + + for (auto index = 0; index <= 6; index++) { + pp = pp * (25.0 / (x * x)) + PP[index]; + } + + float pq = 0.0; + + for (auto index = 0; index <= 6; index++) { + pq = pq * (25.0 / (x * x)) + PQ[index]; + } + + float qp = 0.0; + + for (auto index = 0; index <= 7; index++) { + qp = qp * (25.0 / (x * x)) + QP[index]; + } + + float qq = 0.0; + + for (auto index = 0; index <= 6; index++) { + qq = qq * (25.0 / (x * x)) + QQ[index]; + } + + return (pp / pq * + ::metal::precise::sin( + x - 0.785398163397448309615660845819875721) + + 5.0 / x * (qp / qq) * + ::metal::precise::cos( + x - 0.785398163397448309615660845819875721)) * + 0.797884560802865355879892119868763737 / ::metal::precise::sqrt(x); +} // bessel_y0_forward(T x) + +template +inline float bessel_j1_forward(T x) { + constexpr float PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + constexpr float PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + constexpr float QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + constexpr float QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + constexpr float RP[] = { + -8.99971225705559398224e+08, + +4.52228297998194034323e+11, + -7.27494245221818276015e+13, + +3.68295732863852883286e+15, + }; + + constexpr float RQ[] = { + +6.20836478118054335476e+02, + +2.56987256757748830383e+05, + +8.35146791431949253037e+07, + +2.21511595479792499675e+10, + +4.74914122079991414898e+12, + +7.84369607876235854894e+14, + +8.95222336184627338078e+16, + +5.32278620332680085395e+18, + }; + + if (x < T(0.0)) { + return -bessel_j1_forward(-x); + } + + if (x <= T(5.0)) { + float rp = 0.0; + + for (auto index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + float rq = 0.0; + + for (auto index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return rp / rq * x * (x * x - 1.46819706421238932572e+01) * + (x * x - 4.92184563216946036703e+01); + } + + float pp = 0.0; + + for (auto index = 0; index <= 6; index++) { + pp = pp * (5.0 / x * (5.0 / x)) + PP[index]; + } + + float pq = 0.0; + + for (auto index = 0; index <= 6; index++) { + pq = pq * (5.0 / x * (5.0 / x)) + PQ[index]; + } + + float qp = 0.0; + + for (auto index = 0; index <= 7; index++) { + qp = qp * (5.0 / x * (5.0 / x)) + QP[index]; + } + + float qq = 0.0; + + for (auto index = 0; index <= 6; index++) { + qq = qq * (5.0 / x * (5.0 / x)) + QQ[index]; + } + + return (pp / pq * + ::metal::precise::cos( + x - 2.356194490192344928846982537459627163) - + 5.0 / x * (qp / qq) * + ::metal::precise::sin( + x - 2.356194490192344928846982537459627163)) * + 0.797884560802865355879892119868763737 / ::metal::precise::sqrt(x); +} // bessel_j1_forward(T x) + +template +inline float bessel_y1_forward(T x) { + constexpr float PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + constexpr float PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + constexpr float QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + constexpr float QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + constexpr float YP[] = { + +1.26320474790178026440e+09, + -6.47355876379160291031e+11, + +1.14509511541823727583e+14, + -8.12770255501325109621e+15, + +2.02439475713594898196e+17, + -7.78877196265950026825e+17, + }; + + constexpr float YQ[] = { + +5.94301592346128195359e+02, + +2.35564092943068577943e+05, + +7.34811944459721705660e+07, + +1.87601316108706159478e+10, + +3.88231277496238566008e+12, + +6.20557727146953693363e+14, + +6.87141087355300489866e+16, + +3.97270608116560655612e+18, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -INFINITY; + } + + if (x <= T(0.0)) { + return NAN; + } + + float yp = 0.0; + + for (auto index = 0; index <= 5; index++) { + yp = yp * (x * x) + YP[index]; + } + + float yq = 0.0; + + for (auto index = 0; index <= 7; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return x * (yp / yq) + + (0.636619772367581343075535053490057448 * + (bessel_j1_forward(x) * ::metal::precise::log(x) - 1.0 / x)); + } + + float pp = 0.0; + + for (auto index = 0; index <= 6; index++) { + pp = pp * (5.0 / x * (5.0 / x)) + PP[index]; + } + + float pq = 0.0; + + for (auto index = 0; index <= 6; index++) { + pq = pq * (5.0 / x * (5.0 / x)) + PQ[index]; + } + + float qp = 0.0; + + for (auto index = 0; index <= 7; index++) { + qp = qp * (5.0 / x * (5.0 / x)) + QP[index]; + } + + float qq = 0.0; + + for (auto index = 0; index <= 6; index++) { + qq = qq * (5.0 / x * (5.0 / x)) + QQ[index]; + } + + return (pp / pq * + ::metal::precise::sin( + x - 2.356194490192344928846982537459627163) + + 5.0 / x * (qp / qq) * + ::metal::precise::cos( + x - 2.356194490192344928846982537459627163)) * + 0.797884560802865355879892119868763737 / ::metal::precise::sqrt(x); +} // bessel_y1_forward(T x) + +template +inline float modified_bessel_i0_forward(T x) { + constexpr float A[] = { + -4.41534164647933937950e-18, +3.33079451882223809783e-17, + -2.43127984654795469359e-16, +1.71539128555513303061e-15, + -1.16853328779934516808e-14, +7.67618549860493561688e-14, + -4.85644678311192946090e-13, +2.95505266312963983461e-12, + -1.72682629144155570723e-11, +9.67580903537323691224e-11, + -5.18979560163526290666e-10, +2.65982372468238665035e-09, + -1.30002500998624804212e-08, +6.04699502254191894932e-08, + -2.67079385394061173391e-07, +1.11738753912010371815e-06, + -4.41673835845875056359e-06, +1.64484480707288970893e-05, + -5.75419501008210370398e-05, +1.88502885095841655729e-04, + -5.76375574538582365885e-04, +1.63947561694133579842e-03, + -4.32430999505057594430e-03, +1.05464603945949983183e-02, + -2.37374148058994688156e-02, +4.93052842396707084878e-02, + -9.49010970480476444210e-02, +1.71620901522208775349e-01, + -3.04682672343198398683e-01, +6.76795274409476084995e-01, + }; + + constexpr float B[] = { + -7.23318048787475395456e-18, -4.83050448594418207126e-18, + +4.46562142029675999901e-17, +3.46122286769746109310e-17, + -2.82762398051658348494e-16, -3.42548561967721913462e-16, + +1.77256013305652638360e-15, +3.81168066935262242075e-15, + -9.55484669882830764870e-15, -4.15056934728722208663e-14, + +1.54008621752140982691e-14, +3.85277838274214270114e-13, + +7.18012445138366623367e-13, -1.79417853150680611778e-12, + -1.32158118404477131188e-11, -3.14991652796324136454e-11, + +1.18891471078464383424e-11, +4.94060238822496958910e-10, + +3.39623202570838634515e-09, +2.26666899049817806459e-08, + +2.04891858946906374183e-07, +2.89137052083475648297e-06, + +6.88975834691682398426e-05, +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + }; + + float p; + float q = 0.0; + + if (::metal::fabs(x) <= 8.0) { + float a = A[0]; + + for (uint8_t index = 1; index < 30; index++) { + p = q; + q = a; + a = (.5 * ::metal::fabs(x) - 2.0) * q - p + A[index]; + } + + return ::metal::exp(::metal::fabs(x)) * (T(0.5) * (a - p)); + } + + float b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (32.0 / ::metal::fabs(x) - 2.0) * q - p + B[index]; + } + + return ::metal::exp(::metal::fabs(x)) * (.5 * (b - p)) / + ::metal::precise::sqrt(::metal::fabs(x)); +} // modified_bessel_i0_forward(T x) + +template +inline float modified_bessel_i1_forward(T x) { + constexpr float A[] = { + +2.77791411276104639959e-18, -2.11142121435816608115e-17, + +1.55363195773620046921e-16, -1.10559694773538630805e-15, + +7.60068429473540693410e-15, -5.04218550472791168711e-14, + +3.22379336594557470981e-13, -1.98397439776494371520e-12, + +1.17361862988909016308e-11, -6.66348972350202774223e-11, + +3.62559028155211703701e-10, -1.88724975172282928790e-09, + +9.38153738649577178388e-09, -4.44505912879632808065e-08, + +2.00329475355213526229e-07, -8.56872026469545474066e-07, + +3.47025130813767847674e-06, -1.32731636560394358279e-05, + +4.78156510755005422638e-05, -1.61760815825896745588e-04, + +5.12285956168575772895e-04, -1.51357245063125314899e-03, + +4.15642294431288815669e-03, -1.05640848946261981558e-02, + +2.47264490306265168283e-02, -5.29459812080949914269e-02, + +1.02643658689847095384e-01, -1.76416518357834055153e-01, + +2.52587186443633654823e-01, + }; + + constexpr float B[] = { + +7.51729631084210481353e-18, +4.41434832307170791151e-18, + -4.65030536848935832153e-17, -3.20952592199342395980e-17, + +2.96262899764595013876e-16, +3.30820231092092828324e-16, + -1.88035477551078244854e-15, -3.81440307243700780478e-15, + +1.04202769841288027642e-14, +4.27244001671195135429e-14, + -2.10154184277266431302e-14, -4.08355111109219731823e-13, + -7.19855177624590851209e-13, +2.03562854414708950722e-12, + +1.41258074366137813316e-11, +3.25260358301548823856e-11, + -1.89749581235054123450e-11, -5.58974346219658380687e-10, + -3.83538038596423702205e-09, -2.63146884688951950684e-08, + -2.51223623787020892529e-07, -3.88256480887769039346e-06, + -1.10588938762623716291e-04, -9.76109749136146840777e-03, + +7.78576235018280120474e-01, + }; + + float p; + float q = 0.0; + + if (::metal::fabs(x) <= T(8.0)) { + float a = A[0]; + + for (uint8_t index = 1; index < 29; index++) { + p = q; + q = a; + a = (.5 * ::metal::fabs(x) - 2.0) * q - p + A[index]; + } + + return .5 * (a - p) * x * ::metal::precise::exp(::metal::fabs(x)); + } + + float b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (32.0 / ::metal::fabs(x) - 2.0) * q - p + B[index]; + } + + if (x < 0.0) { + return -( + ::metal::precise::exp(::metal::fabs(x)) * (0.5 * (b - p)) / + ::metal::precise::sqrt(::metal::fabs(x))); + } + + return ::metal::precise::exp(::metal::fabs(x)) * (0.5 * (b - p)) / + ::metal::precise::sqrt(::metal::fabs(x)); +} // modified_bessel_i1_forward(T x) + +template +inline float modified_bessel_k0_forward(T x) { + constexpr float A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + constexpr float B[] = { + +5.30043377268626276149e-18, -1.64758043015242134646e-17, + +5.21039150503902756861e-17, -1.67823109680541210385e-16, + +5.51205597852431940784e-16, -1.84859337734377901440e-15, + +6.34007647740507060557e-15, -2.22751332699166985548e-14, + +8.03289077536357521100e-14, -2.98009692317273043925e-13, + +1.14034058820847496303e-12, -4.51459788337394416547e-12, + +1.85594911495471785253e-11, -7.95748924447710747776e-11, + +3.57739728140030116597e-10, -1.69753450938905987466e-09, + +8.57403401741422608519e-09, -4.66048989768794782956e-08, + +2.76681363944501510342e-07, -1.83175552271911948767e-06, + +1.39498137188764993662e-05, -1.28495495816278026384e-04, + +1.56988388573005337491e-03, -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == 0.0) { + return INFINITY; + } + + if (x < 0.0) { + return NAN; + } + + float p; + float q = 0.0; + + if (x <= 2.0) { + float a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - 2.0) * q - p + A[index]; + } + + return 0.5 * (a - p) - + ::metal::log(0.5 * x) * modified_bessel_i0_forward(x); + } + + float b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (8.0 / x - 2.0) * q - p + B[index]; + } + + return ::metal::exp(-x) * (0.5 * (b - p)) / ::metal::sqrt(x); +} // modified_bessel_k0_forward(T x) + +template +inline float modified_bessel_k1_forward(T x) { + constexpr float A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + constexpr float B[] = { + -5.75674448366501715755e-18, +1.79405087314755922667e-17, + -5.68946255844285935196e-17, +1.83809354436663880070e-16, + -6.05704724837331885336e-16, +2.03870316562433424052e-15, + -7.01983709041831346144e-15, +2.47715442448130437068e-14, + -8.97670518232499435011e-14, +3.34841966607842919884e-13, + -1.28917396095102890680e-12, +5.13963967348173025100e-12, + -2.12996783842756842877e-11, +9.21831518760500529508e-11, + -4.19035475934189648750e-10, +2.01504975519703286596e-09, + -1.03457624656780970260e-08, +5.74108412545004946722e-08, + -3.50196060308781257119e-07, +2.40648494783721712015e-06, + -1.93619797416608296024e-05, +1.95215518471351631108e-04, + -2.85781685962277938680e-03, +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == 0.0) { + return INFINITY; + } + + if (x < 0.0) { + return NAN; + } + + float p; + float q = 0.0; + + if (x <= 2.0) { + float a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return ::metal::precise::log(T(0.5) * x) * modified_bessel_i1_forward(x) + + 0.5 * (a - p) / x; + } + + float b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (8.0 / x - 2.0) * q - p + B[index]; + } + + return ::metal::precise::exp(-x) * (0.5 * (b - p)) / + ::metal::precise::sqrt(x); +} + +template +inline float scaled_modified_bessel_k0_forward(T x) { + constexpr float A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + constexpr float B[] = { + +5.30043377268626276149e-18, -1.64758043015242134646e-17, + +5.21039150503902756861e-17, -1.67823109680541210385e-16, + +5.51205597852431940784e-16, -1.84859337734377901440e-15, + +6.34007647740507060557e-15, -2.22751332699166985548e-14, + +8.03289077536357521100e-14, -2.98009692317273043925e-13, + +1.14034058820847496303e-12, -4.51459788337394416547e-12, + +1.85594911495471785253e-11, -7.95748924447710747776e-11, + +3.57739728140030116597e-10, -1.69753450938905987466e-09, + +8.57403401741422608519e-09, -4.66048989768794782956e-08, + +2.76681363944501510342e-07, -1.83175552271911948767e-06, + +1.39498137188764993662e-05, -1.28495495816278026384e-04, + +1.56988388573005337491e-03, -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == 0.0) { + return INFINITY; + } + + if (x < 0.0) { + return NAN; + } + + float p; + float q = 0.0; + + if (x <= 2.0) { + float a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (0.5 * (a - p) - + ::metal::precise::log(0.5 * x) * modified_bessel_i0_forward(x)) * + ::metal::precise::exp(x); + } + + float b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (8.0 / x - 2.0) * q - p + B[index]; + } + + return 0.5 * (b - p) / ::metal::precise::sqrt(x); +} + +template +inline float scaled_modified_bessel_k1_forward(T x) { + constexpr float A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + constexpr float B[] = { + -5.75674448366501715755e-18, +1.79405087314755922667e-17, + -5.68946255844285935196e-17, +1.83809354436663880070e-16, + -6.05704724837331885336e-16, +2.03870316562433424052e-15, + -7.01983709041831346144e-15, +2.47715442448130437068e-14, + -8.97670518232499435011e-14, +3.34841966607842919884e-13, + -1.28917396095102890680e-12, +5.13963967348173025100e-12, + -2.12996783842756842877e-11, +9.21831518760500529508e-11, + -4.19035475934189648750e-10, +2.01504975519703286596e-09, + -1.03457624656780970260e-08, +5.74108412545004946722e-08, + -3.50196060308781257119e-07, +2.40648494783721712015e-06, + -1.93619797416608296024e-05, +1.95215518471351631108e-04, + -2.85781685962277938680e-03, +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == 0.0) { + return INFINITY; + } + + if (x < 0.0) { + return NAN; + } + + float p; + float q = 0.0; + + if (x <= 2.0) { + float a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - 2.0) * q - p + A[index]; + } + + return (::metal::precise::log(0.5 * x) * modified_bessel_i1_forward(x) + + 0.5 * (a - p) / x) * + ::metal::precise::exp(x); + } + + float b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (8.0 / x - 2.0) * q - p + B[index]; + } + + return (0.5 * (b - p) / ::metal::precise::sqrt(x)); +} + +template +float chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (::metal::fabs(x) == 1.0) { + if (x > 0.0 || n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + if ((n > 6) && (::metal::precise::fabs(x) < 1.0)) { + return ::metal::precise::cos(n * ::metal::precise::acos(x)); + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return x; + } + + float p = 1.0; + float q = x; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + return r; +} + +template +float chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (::metal::fabs(x) == 1.0) { + if (x > 0.0 || n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 8) && (::metal::fabs(x) < 1.0)) { + const auto acos_x = ::metal::precise::acos(x); + if (::metal::precise::sin(acos_x) != 0.0) { + return ::metal::precise::sin((n + 1) * acos_x) / + ::metal::precise::sin(acos_x); + } + + return (n + 1) * ::metal::precise::cos((n + 1) * acos_x) / x; + } + + if (n == 0) { + return 1.0; + } + + auto q = 2.0 * x; + if (n == 1) { + return q; + } + + auto p = 1.0; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = 2 * x * q - p; + p = q; + q = r; + } + + return r; +} + +template +float chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (::metal::fabs(x) == 1.0) { + if (x > 0.0) { + return 1.0; + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if ((n > 8) && (::metal::fabs(x) < 1.0)) { + const auto acos_x = ::metal::precise::acos(x); + if (::metal::precise::sin(.5 * acos_x) != 1.0) { + return ::metal::precise::cos((n + 0.5) * acos_x) / + ::metal::precise::cos(.5 * acos_x); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return 1.0; + } + + auto q = 2.0 * x - 1.0; + if (n == 1) { + return q; + } + + auto p = 1.0; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = 2 * x * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_v_forward(T x, int64_t n) + +template +float chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (::metal::fabs(x) == 1.0) { + if (x > 0.0) { + return n + n + 1; + } + + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + if ((n > 8) && (::metal::fabs(x) < 1.0)) { + const auto acos_x = ::metal::precise::acos(x); + if (::metal::precise::cos(.5 * acos_x) != 1.0) { + return ::metal::precise::sin((n + 0.5) * acos_x) / + ::metal::precise::sin(.5 * acos_x); + } + + if (x > 0.0) { + return n + n + 1; + } + + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + if (n == 0) { + return 1.0; + } + + auto q = 2.0 * x + 1.0; + if (n == 1) { + return q; + } + + auto p = 1.0; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = 2.0 * x * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_w_forward(T x, int64_t n) + +template +float shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == T(1.0)) { + return 1.0; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + const float xpxm1 = x + x - 1.0; + if ((n > 6) && (::metal::abs(xpxm1) < 1.0)) { + return ::metal::precise::cos(n * ::metal::precise::acos(xpxm1)); + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return xpxm1; + } + + float p = 1.0; + float q = xpxm1; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + +template +float shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == 1.0) { + return n + 1; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + const float xpxm1 = x + x - 1.0; + if ((n > 6) && (::metal::abs(xpxm1) < 1.0)) { + const float acos_2xm1 = ::metal::precise::acos(xpxm1); + const float divisor = ::metal::precise::sin(acos_2xm1); + if (divisor != 0.0) { + return ::metal::precise::sin((n + 1) * acos_2xm1) / divisor; + } + + return (n + 1) * ::metal::precise::cos((n + 1) * acos_2xm1) / xpxm1; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return xpxm1 + xpxm1; + } + + float p = 1.0; + float q = xpxm1 + xpxm1; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + +template +float shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == 1.0) { + return 1.0; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); + } + + const float xpxm1 = x + x - 1.0; + if ((n > 6) && (::metal::abs(xpxm1) < 1.0)) { + const float acos_2xm1 = ::metal::precise::acos(xpxm1); + if (::metal::precise::sin(acos_2xm1 / 2.0) != 1.0) { + return ::metal::precise::cos((n + 0.5) * acos_2xm1) / + ::metal::precise::cos(acos_2xm1 / 2.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return xpxm1 + xpxm1 - 1.0; + } + + float p = 1.0; + float q = xpxm1 + xpxm1 - 1.0; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + +template +float shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (x == 1.0) { + return n + n + 1; + } + + if (x == 0.0) { + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + const float xpxm1 = x + x - 1.0; + if ((n > 4) && (::metal::abs(xpxm1) < 1.0)) { + const float acos_2xm1 = ::metal::precise::acos(xpxm1); + if (::metal::precise::cos(acos_2xm1 / 2.0) != 1.0) { + return ::metal::precise::sin((n + 0.5) * acos_2xm1) / + ::metal::precise::sin(acos_2xm1 / 2.0); + } + + if (n % 2 == 0) { + return 1.0; + } + + return -1.0; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return xpxm1 + xpxm1 + 1.0; + } + + float p = 1.0; + float q = xpxm1 + xpxm1 + 1.0; + float r; + + for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { + r = (xpxm1 + xpxm1) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + +template +// TODO: Add 512 if/when double will be supported in Metal +inline constexpr int getHermitianLimit() { + return 128; +} + +template +inline float hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return x + x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + float p = 1.0; + float q = x + x; + float r = 0.0; + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_h_forward(T x, int64_t n) + +template +inline float hermite_polynomial_he_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + float p = 1.0; + float q = x; + float r; + + for (int64_t k = 1; k < n; k++) { + r = x * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_he_forward(T x, int64_t n) + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..13c23ac7ed705a4e8fc76ba144f603be82a9c503 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/metal/utils.h @@ -0,0 +1,386 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Metal helper functions +#pragma once +#include +#include + +namespace c10 { +namespace metal { + +namespace detail { +template +struct vectypes {}; + +template <> +struct vectypes { + using type4 = float4; + using type3 = float3; + using type2 = float2; +}; + +template <> +struct vectypes { + using type4 = half4; + using type3 = half3; + using type2 = half2; +}; + +template <> +struct vectypes { + using type4 = bfloat4; + using type3 = bfloat3; + using type2 = bfloat2; +}; + +template <> +struct vectypes { + using type4 = short4; + using type3 = short3; + using type2 = short2; +}; + +template <> +struct vectypes { + using type4 = int4; + using type3 = int3; + using type2 = int2; +}; + +template <> +struct vectypes { + using type4 = short4; + using type3 = short3; + using type2 = short2; +}; + +template +struct OpMathType { + using type = T; +}; + +template <> +struct OpMathType { + using type = float; +}; + +template <> +struct OpMathType { + using type = int; +}; + +template <> +struct OpMathType { + using type = int; +}; + +template <> +struct OpMathType { + using type = int; +}; + +template <> +struct OpMathType { + using type = float; +}; + +// Type promotion structure for higher precision accumulation +template +struct AccumulationType { + using type = T; +}; + +// Specialization for half - promote to float for accumulation +template <> +struct AccumulationType { + using type = float; +}; + +// Specialization for bfloat - promote to float for accumulation +template <> +struct AccumulationType { + using type = float; +}; + +} // namespace detail + +template +::metal::enable_if_t<::metal::is_floating_point_v, T> max(T a, T b) { + return ::metal::isunordered(a, b) ? NAN : ::metal::max(a, b); +} + +template +::metal::enable_if_t<::metal::is_integral_v&& ::metal::is_integral_v, T> +max(T a, U b) { + return ::metal::max(a, static_cast(b)); +} + +template +::metal::enable_if_t<::metal::is_floating_point_v, T> min(T a, T b) { + return ::metal::isunordered(a, b) ? NAN : ::metal::min(a, b); +} + +template +::metal::enable_if_t<::metal::is_integral_v&& ::metal::is_integral_v, T> +min(T a, U b) { + return ::metal::min(a, static_cast(b)); +} + +template <> +inline bfloat min(bfloat a, bfloat b) { + return bfloat( + ::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b))); +} + +template <> +inline bfloat max(bfloat a, bfloat b) { + return bfloat( + ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b))); +} + +template +using vec2type_t = typename detail::vectypes::type2; + +template +using vec4type_t = typename detail::vectypes::type4; + +template +using opmath_t = typename detail::OpMathType::type; + +template +using accum_t = typename detail::AccumulationType::type; + +// TODO: Move it to type_traits header may be +template +using result_of = decltype(::metal::declval()(::metal::declval()...)); + +template +constexpr constant bool is_complex_v = + ::metal::is_same_v || ::metal::is_same_v; + +template +constexpr constant bool is_scalar_floating_point_v = + ::metal::is_floating_point_v && ::metal::is_scalar_v; + +template +constexpr constant bool is_scalar_integral_v = + ::metal::is_integral_v && ::metal::is_scalar_v; + +template +using common_dtype = decltype(U(0) + V(0)); + +// floor_divide +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_integral_v && is_scalar_integral_v, + bool> = true> +inline common_dtype floor_divide(T x, U y) { + const auto quot = x / y; + return (x < 0) == (y < 0) ? quot : (x % y != 0) ? quot - 1 : quot; +} + +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_floating_point_v && is_scalar_floating_point_v, + bool> = true> +inline common_dtype floor_divide(T x, U y) { + return ::metal::floor(x / y); +} + +// fmod +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_integral_v && is_scalar_integral_v, + bool> = true> +inline common_dtype fmod(T x, U y) { + return x % y; +} + +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_floating_point_v && is_scalar_floating_point_v, + bool> = true> +inline common_dtype fmod(T x, U y) { + return ::metal::fmod(x, y); +} + +// cast_to primitives +// - No-op if types as the same +template < + typename T, + typename U, + ::metal::enable_if_t<::metal::is_same_v, bool> = true> +inline T cast_to(const U from) { + return from; +} +// - Simple cast between scalar and complex dtypes +template < + typename T, + typename U, + ::metal::enable_if_t< + !::metal::is_same_v && (is_complex_v == is_complex_v), + bool> = true> +inline T cast_to(const U from) { + return static_cast(from); +} + +// - Scalar to complex +template < + typename T, + typename U, + ::metal::enable_if_t && !is_complex_v, bool> = true> +inline T cast_to(const U from) { + return T(float(from), 0.0); +} +// - Complex to scalar (should not really be used, but exists for compliteness) +template < + typename T, + typename U, + ::metal::enable_if_t && is_complex_v, bool> = true> +inline T cast_to(const U from) { + return static_cast(from.x); +} + +// Generalizable math operators (used for both scalar and complex) + +template < + typename T, + typename U, + ::metal::enable_if_t, bool> = true> +inline common_dtype mul(const T x, const U y) { + return x * y; +} + +template < + typename T, + typename U, + ::metal::enable_if_t && is_complex_v, bool> = true> +inline common_dtype mul(const T x, const U y) { + return T(x.x * y.x - x.y * y.y, x.x * y.y + x.y * y.x); +} + +template < + typename T, + typename U, + ::metal::enable_if_t, bool> = true> +inline common_dtype div(const T x, const U y) { + return x / y; +} + +template < + typename T, + typename U, + ::metal::enable_if_t && is_complex_v, bool> = true> +inline common_dtype div(const T x, const U y) { + return T(::metal::dot(x, y), x.y * y.x - x.x * y.y) / ::metal::dot(y, y); +} + +// Remainder operator +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_floating_point_v || is_scalar_floating_point_v, + bool> = true> +inline float remainder(const T x, const U y) { + const auto x_f = static_cast(x); + const auto y_f = static_cast(y); + return x_f - y_f * floor_divide(x_f, y_f); +} + +template < + typename T, + typename U, + ::metal::enable_if_t< + is_scalar_integral_v && is_scalar_integral_v, + bool> = true> +inline common_dtype remainder(const T x, const U y) { + auto rc = x % y; + return rc == 0 || (x ^ y) > 0 ? rc : rc + y; +} + +// Based on algorithm described in +// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + const auto xp1 = 1.0f + x; + // First two elements of Taylor series for log(1+x) in Horner's form are: + // log(1+x) = x * (1 - x * (.5 ...)), but if 1 + x == x, then it's just x + if (xp1 == 1.0f) { + return x; + } + auto rc = ::metal::precise::log(xp1); + if (x > -.5 && x < .5) { + // Order of operations is important here for higher precision + rc *= x / (xp1 - 1.0f); + } + return rc; +} + +// The function is ported from mlx +inline float2 log1p(float2 in) { + float x = in.x; + float y = in.y; + float zabs = ::metal::precise::sqrt(x * x + y * y); + float theta = ::metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = ::metal::sqrt((x + 1) * (x + 1) + y * y); + return {::metal::log(z0), theta}; + } +} + +template +struct pair { + T1 first; + T2 second; +}; + +template +inline T conj(T a) { + return a; +} + +template <> +inline half2 conj(half2 a) { + return half2(a.x, -a.y); +} + +template <> +inline float2 conj(float2 a) { + return float2(a.x, -a.y); +} + +#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \ + MACRO(float); \ + MACRO(half); \ + MACRO(bfloat); \ + MACRO(float2); \ + MACRO(long); \ + MACRO(char); \ + MACRO(uchar); \ + MACRO(short); \ + MACRO(int); + +#define INSTANTIATE_FOR_FLOAT_TYPES(MACRO) \ + MACRO(float); \ + MACRO(half); \ + MACRO(bfloat); + +} // namespace metal +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/mobile/CPUCachingAllocator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/mobile/CPUCachingAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..ad6854b8871d9e55324bea686b1313f64c1f5883 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/mobile/CPUCachingAllocator.h @@ -0,0 +1,111 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include + +/* + * CPUCachingAllocator: + * DISCLAIMER: + * This is subject to change (beta) and only supported on mobile builds. + * If code snippet such as in 'Usage pattern' is used outside of mobile + * build you will not observe the intended behavior. + * See below for more information. + * Why? + * It has been observed that some mobile platforms, such as pixel 3, return + * memory aggressively to the system. This results in page faults in some + * cases and ends up hurting performance. This caching allocator aims to address + * that. Furthermore it also allows users to specify their own allocator by + * implementing allocate/free virtual interfaces. What are the cons? There are + * some cons that were observed where use of caching allocator led to worse + * performance on some platforms. Reason being that the caching mechanism used + * by this allocator left us worse off compared to the corresponding platform's + * tuned memory allocator. In that case it seemed better to not use this + * allocator. Note there are some ideas to fix this in the works. + * + * Usage: + * Usage pattern: + * Instantiate and own the caching allocator. + * std::unique_ptr caching_allocator = + * std::make_unique(); + * Use caching allocator with a scoped guard at inference time. + * { + * WithCPUCachingAllocatorGuard(caching_allocator.get()); + * ... model.forward(...); + * } + */ + +namespace c10 { + +class C10_API CPUCachingAllocator { + /* + * What it does: + * Caches all the allocations carried out by this allocator. + * Cache key is the size of the allocation. + * If requested size is found in the cache returns the cached pointer. + * What it does not do: + * No speculative allocation for any future allocations. + */ + private: + inline void* allocate_and_cache(const size_t bytes); + void free_cached(); + + protected: + // Invariants. + // 1. If memory is ever allocated via this allocator then + // the pointer will exist in allocation_map_, unless the allocator + // returned the memory to OS via free_cached. + // 1.1. Therefore even when the said memory is "freed" via this + // allocator (and thus cached), it will continue to stay + // in allocation_map_. Furthermore it will also exist in + // available_map_. Thus an allocated memory pointer can be in both + // allocation_map_ and available_map_ simultaneously. + // 2. Memory pointer maybe removed from allocation_map_, when it + // is freed outside of the scope of this allocator, but was allocated + // by this allocator. + // 3. Available map only contains that memory which was allocated + // by this allocator and subsequently freed by this allocator. + // As a result of above invariants, allocated memory ptr cannot be in + // available_map_ unless it is in allocation_map_ as well. + ska::flat_hash_map> available_map_; + static ska::flat_hash_map allocation_map_; + // Since allocation_map, which is a global instance, is mutated/read via + // all public APIs we need a global mutex. + static std::mutex mutex_; + + public: + static void record_free(void* ptr); + virtual ~CPUCachingAllocator(); + // Checks the cache to see if allocation of size bytes can be found. + // If so return cached memory, else + // allocates memory, records it for caching and returns. + virtual void* allocate(const size_t bytes); + // Checks if the memory being freed is was marked for allocation by + // an earlier call to allocate. If so cache the allocation. + // Otherwise free. + virtual void free(void* ptr); +}; + +CPUCachingAllocator* GetDefaultCPUCachingAllocator(); + +bool ThreadLocalCachingAllocatorEnabled(); +CPUCachingAllocator* GetThreadLocalCachingAllocator(); + +class C10_API WithCPUCachingAllocatorGuard { + public: + WithCPUCachingAllocatorGuard(CPUCachingAllocator* allocator); + ~WithCPUCachingAllocatorGuard(); + + private: + CPUCachingAllocator* prev_caching_allocator_ptr_{nullptr}; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/mobile/CPUProfilingAllocator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/mobile/CPUProfilingAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..07064210e115bb5799906828fac135ccb63a3146 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/mobile/CPUProfilingAllocator.h @@ -0,0 +1,157 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/* + * Given a sequence of allocations in a thread, AllocationPlan records + * 1. size of each allocation + * 2. Lifetime of each allocation. + * 3. allocation offsets: Memory offset for each allocation in a single blob of + * memory + * 4. Total size of a blob of memory required to satisfy all the allocations. + */ +class C10_API AllocationPlan { + private: + // Records size of each allocation by their sequential allocation ids. + std::vector allocation_sizes; + // This maps one allocation id (X) to another allocation id (Y). + // Allocation X is alive until allocation Y. From allocation Y onwards + // allocation X is not referenced. + // Thus Y is the id of the first allocation after X is freed. + // NB: When an allocation is recorded, along with recording its size, + // we also set the lifetime to be numeric_limits::max() + // This is to track allocations that are made during the scope of + // profiling but were not freed until after the scope ended. + // Such allocations are not managed by profiling allocator. + std::vector allocation_lifetimes; + // Maps an allocation to some offset in a blob of memory. + std::vector allocation_offsets; + uint64_t total_size{0}; + void clear(); + friend class AllocationPlanner; + friend class CPUProfilingAllocator; +}; + +/* + * Map of memory ptr to allocation id. This is auxiliary information only + * used to establish lifetime of allocations. + */ +class C10_API AllocationPlanner { + private: + AllocationPlan* allocation_plan_{nullptr}; + // Maps allocated ptr to its allocation id. + // This is used when freeing the memory to look up the allocation id + // in order to establish the lifetime of a particular allocation. + ska::flat_hash_map allocation_ptr_to_id_; + uint64_t allocation_id_{0}; + bool validation_mode_{false}; + + bool validate_allocation(const uint64_t size, const void* ptr); + bool validate_free(const void* ptr); + + public: + bool validation_success{true}; + + AllocationPlanner() = delete; + AllocationPlanner(AllocationPlan* plan, bool validate = false) + : allocation_plan_(plan), validation_mode_(validate) {} + void record_allocation(const uint64_t size, const void* ptr); + void record_free(const void* ptr); + void formulate_plan(); + void clear(); +}; + +// NOT THREAD SAFE profiling allocator. +class C10_API CPUProfilingAllocator { + private: + const AllocationPlan* plan_{nullptr}; + uint64_t allocation_id_{0}; + uint64_t current_size_{0}; + void* blob_{nullptr}; + ska::flat_hash_map allocation_ptr_to_id_; + + public: + ~CPUProfilingAllocator(); + void set_plan(const AllocationPlan* plan); + void unset_plan(); + void* allocate(const size_t bytes); + void free(void* const ptr); +}; + +/* + * Usage: Profile allocations made by one run of the model. + * AllocationPlan plan; + * { + * WithProfileAllocationGuard profile_guard(&plan); + * module.forward(...); + * } + * plan now contains allocation plan. + */ +class C10_API WithProfileAllocationsGuard { + public: + WithProfileAllocationsGuard(AllocationPlan* plan); + ~WithProfileAllocationsGuard(); + + private: + std::unique_ptr planner_; +}; + +/* + * Usage: Validate allocation plan made with WithProfileAllocationGuard + * bool plan_validation_success, success = true; + * for (some number of representative inputs) + * { + * WithValidateAllocationPlanGuard(&plan, &plan_validation_success); + * module.forward(...); + * success = success && plan_validation_success; + * } + * success == true means allocations are according to plan + * else for some inputs allocation pattern changed. + */ +class C10_API WithValidateAllocationPlanGuard { + public: + WithValidateAllocationPlanGuard(AllocationPlan* plan, bool* success); + ~WithValidateAllocationPlanGuard(); + + private: + std::unique_ptr planner_; + bool* success_; +}; + +AllocationPlanner* GetThreadLocalAllocationPlanner(); + +/* + * Usage: Allocate tensors accordingly to allocation plan + * First make allocation plan. + * See WithProfileAllocationsGuard usage. + * Second validate allocation plan. + * See WithValidateAllocationPlanGuard usage. + * CPUProfilingAllocator profiling_allocator; + * { + * WithProfilingAllocatorGuard allocator_guard(&profiling_allocator, &plan); + * module.forward(...); + * } + */ +class C10_API WithProfilingAllocatorGuard { + public: + WithProfilingAllocatorGuard( + CPUProfilingAllocator* allocator, + const AllocationPlan* plan); + ~WithProfilingAllocatorGuard(); +}; + +CPUProfilingAllocator* GetThreadLocalProfilingAllocator(); + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/Macros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/Macros.h new file mode 100644 index 0000000000000000000000000000000000000000..026570edcd7f2be024266f65b5745a65036bbeed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/Macros.h @@ -0,0 +1,14 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_TEST_CORE_MACROS_MACROS_H_ + +#ifdef _WIN32 +#define DISABLED_ON_WINDOWS(x) DISABLED_##x +#else +#define DISABLED_ON_WINDOWS(x) x +#endif + +#endif // C10_MACROS_MACROS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/complex_math_test_common.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/complex_math_test_common.h new file mode 100644 index 0000000000000000000000000000000000000000..a68a35cd968a95ef35b61b92594837fcbdbf79a6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/complex_math_test_common.h @@ -0,0 +1,672 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Warning: this file is included twice in +// aten/src/ATen/test/cuda_complex_math_test.cu + +#include +#include + +#ifndef PI +#define PI 3.141592653589793238463 +#endif + +#ifndef tol +#define tol 1e-6 +#endif + +// Exponential functions + +C10_DEFINE_TEST(TestExponential, IPi) { + // exp(i*pi) = -1 + { + c10::complex e_i_pi = std::exp(c10::complex(0, float(PI))); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + } + { + c10::complex e_i_pi = ::exp(c10::complex(0, float(PI))); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + } + { + c10::complex e_i_pi = std::exp(c10::complex(0, PI)); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + } + { + c10::complex e_i_pi = ::exp(c10::complex(0, PI)); + C10_ASSERT_NEAR(e_i_pi.real(), -1, tol); + C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol); + } +} + +C10_DEFINE_TEST(TestExponential, EulerFormula) { + // exp(ix) = cos(x) + i * sin(x) + { + c10::complex x(0.1, 1.2); + c10::complex e = std::exp(x); + float expected_real = std::exp(x.real()) * std::cos(x.imag()); + float expected_imag = std::exp(x.real()) * std::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex e = ::exp(x); + float expected_real = ::exp(x.real()) * ::cos(x.imag()); + float expected_imag = ::exp(x.real()) * ::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex e = std::exp(x); + float expected_real = std::exp(x.real()) * std::cos(x.imag()); + float expected_imag = std::exp(x.real()) * std::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex e = ::exp(x); + float expected_real = ::exp(x.real()) * ::cos(x.imag()); + float expected_imag = ::exp(x.real()) * ::sin(x.imag()); + C10_ASSERT_NEAR(e.real(), expected_real, tol); + C10_ASSERT_NEAR(e.imag(), expected_imag, tol); + } +} + +C10_DEFINE_TEST(TestExpm1, Normal) { + // expm1(x) = exp(x) - 1 + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::expm1(x); + c10::complex l2 = std::exp(x) - 1.0f; + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::expm1(x); + c10::complex l2 = std::exp(x) - 1.0; + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } +} + +C10_DEFINE_TEST(TestExpm1, Small) { + // expm1(x) = exp(x) - 1 + // expm1(x) provides greater precision than exp(x) - 1 for small values of x + { + c10::complex x(1e-30, 1e-30); + c10::complex l1 = std::expm1(x); + C10_ASSERT_NEAR(l1.real(), 1e-30, tol); + C10_ASSERT_NEAR(l1.imag(), 1e-30, tol); + } + { + c10::complex x(1e-100, 1e-100); + c10::complex l1 = std::expm1(x); + C10_ASSERT_NEAR(l1.real(), 1e-30, tol); + C10_ASSERT_NEAR(l1.imag(), 1e-30, tol); + } +} + +C10_DEFINE_TEST(TestLog, Definition) { + // log(x) = log(r) + i*theta + { + c10::complex x(1.2, 3.4); + c10::complex l = std::log(x); + float expected_real = std::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + } + { + c10::complex x(1.2, 3.4); + c10::complex l = ::log(x); + float expected_real = ::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + } + { + c10::complex x(1.2, 3.4); + c10::complex l = std::log(x); + float expected_real = std::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + } + { + c10::complex x(1.2, 3.4); + c10::complex l = ::log(x); + float expected_real = ::log(std::abs(x)); + float expected_imag = std::arg(x); + C10_ASSERT_NEAR(l.real(), expected_real, tol); + C10_ASSERT_NEAR(l.imag(), expected_imag, tol); + } +} + +C10_DEFINE_TEST(TestLog10, Rev) { + // log10(10^x) = x + { + c10::complex x(0.1, 1.2); + c10::complex l = std::log10(std::pow(float(10), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l = ::log10(::pow(float(10), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l = std::log10(std::pow(double(10), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l = ::log10(::pow(double(10), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + } +} + +C10_DEFINE_TEST(TestLog2, Rev) { + // log2(2^x) = x + { + c10::complex x(0.1, 1.2); + c10::complex l = std::log2(std::pow(float(2), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l = ::log2(std::pow(float(2), x)); + C10_ASSERT_NEAR(l.real(), float(0.1), tol); + C10_ASSERT_NEAR(l.imag(), float(1.2), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l = std::log2(std::pow(double(2), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l = ::log2(std::pow(double(2), x)); + C10_ASSERT_NEAR(l.real(), double(0.1), tol); + C10_ASSERT_NEAR(l.imag(), double(1.2), tol); + } +} + +C10_DEFINE_TEST(TestLog1p, Normal) { + // log1p(x) = log(1 + x) + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::log1p(x); + c10::complex l2 = std::log(1.0f + x); + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::log1p(x); + c10::complex l2 = std::log(1.0 + x); + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } +} + +C10_DEFINE_TEST(TestLog1p, Small) { + // log(1 + x) ~ x for |x| << 1 + { + c10::complex x(1e-9, 2e-9); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real() / x.real(), 1, tol); + C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol); + } + { + c10::complex x(1e-100, 2e-100); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real() / x.real(), 1, tol); + C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol); + } +} + +C10_DEFINE_TEST(TestLog1p, Extreme) { + // log(1 + x) ~ x for |x| << 1 and in the brink of overflow / underflow + { + c10::complex x(-1, 1e-30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), -69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(-1, 1e30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(1e30, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1e-30, tol); + } + { + c10::complex x(1e-30, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e30, 1e30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.42412638010134, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e-38, 1e-38); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-38, tol); + C10_ASSERT_NEAR(l.imag(), 1e-38, tol); + } + { + c10::complex x(1e-38, 2e-30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-30, tol); + C10_ASSERT_NEAR(l.imag(), 2e-30, tol); + } + { + c10::complex x(-1, 1e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), -575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(-1, 1e250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(1e250, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1e-250, tol); + } + { + c10::complex x(1e-250, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e250, 1e250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.9928468387914, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e-250, 1e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-250, tol); + C10_ASSERT_NEAR(l.imag(), 1e-250, tol); + } + { + c10::complex x(1e-250, 2e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-250, tol); + C10_ASSERT_NEAR(l.imag(), 2e-250, tol); + } + { + c10::complex x(2e-308, 1.5e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 2e-308, tol); + C10_ASSERT_NEAR(l.imag(), 1.5e-308, tol); + } +} + +// Power functions + +C10_DEFINE_TEST(TestPowSqrt, Equal) { + // x^0.5 = sqrt(x) + { + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, float(0.5)); + c10::complex z = std::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, float(0.5)); + c10::complex z = ::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, double(0.5)); + c10::complex z = std::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, double(0.5)); + c10::complex z = ::sqrt(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } +} + +C10_DEFINE_TEST(TestPow, Square) { + // x^2 = x * x + { + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, float(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, float(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = std::pow(x, double(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::pow(x, double(2)); + c10::complex z = x * x; + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } +} + +// Trigonometric functions and hyperbolic functions + +C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) { + // sin(x + i * y) = sin(x) * cosh(y) + i * cos(x) * sinh(y) + // cos(x + i * y) = cos(x) * cosh(y) - i * sin(x) * sinh(y) + { + c10::complex x(0.1, 1.2); + c10::complex y = std::sin(x); + float expected_real = std::sin(x.real()) * std::cosh(x.imag()); + float expected_imag = std::cos(x.real()) * std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::sin(x); + float expected_real = ::sin(x.real()) * ::cosh(x.imag()); + float expected_imag = ::cos(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = std::cos(x); + float expected_real = std::cos(x.real()) * std::cosh(x.imag()); + float expected_imag = -std::sin(x.real()) * std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::cos(x); + float expected_real = ::cos(x.real()) * ::cosh(x.imag()); + float expected_imag = -::sin(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = std::sin(x); + float expected_real = std::sin(x.real()) * std::cosh(x.imag()); + float expected_imag = std::cos(x.real()) * std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::sin(x); + float expected_real = ::sin(x.real()) * ::cosh(x.imag()); + float expected_imag = ::cos(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = std::cos(x); + float expected_real = std::cos(x.real()) * std::cosh(x.imag()); + float expected_imag = -std::sin(x.real()) * std::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::cos(x); + float expected_real = ::cos(x.real()) * ::cosh(x.imag()); + float expected_imag = -::sin(x.real()) * ::sinh(x.imag()); + C10_ASSERT_NEAR(y.real(), expected_real, tol); + C10_ASSERT_NEAR(y.imag(), expected_imag, tol); + } +} + +C10_DEFINE_TEST(TestTan, Identity) { + // tan(x) = sin(x) / cos(x) + { + c10::complex x(0.1, 1.2); + c10::complex y = std::tan(x); + c10::complex z = std::sin(x) / std::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::tan(x); + c10::complex z = ::sin(x) / ::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = std::tan(x); + c10::complex z = std::sin(x) / std::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::tan(x); + c10::complex z = ::sin(x) / ::cos(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } +} + +C10_DEFINE_TEST(TestTanh, Identity) { + // tanh(x) = sinh(x) / cosh(x) + { + c10::complex x(0.1, 1.2); + c10::complex y = std::tanh(x); + c10::complex z = std::sinh(x) / std::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::tanh(x); + c10::complex z = ::sinh(x) / ::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = std::tanh(x); + c10::complex z = std::sinh(x) / std::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex y = ::tanh(x); + c10::complex z = ::sinh(x) / ::cosh(x); + C10_ASSERT_NEAR(y.real(), z.real(), tol); + C10_ASSERT_NEAR(y.imag(), z.imag(), tol); + } +} + +// Rev trigonometric functions + +C10_DEFINE_TEST(TestRevTrigonometric, Rev) { + // asin(sin(x)) = x + // acos(cos(x)) = x + // atan(tan(x)) = x + { + c10::complex x(0.5, 0.6); + c10::complex s = std::sin(x); + c10::complex ss = std::asin(s); + c10::complex c = std::cos(x); + c10::complex cc = std::acos(c); + c10::complex t = std::tan(x); + c10::complex tt = std::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } + { + c10::complex x(0.5, 0.6); + c10::complex s = ::sin(x); + c10::complex ss = ::asin(s); + c10::complex c = ::cos(x); + c10::complex cc = ::acos(c); + c10::complex t = ::tan(x); + c10::complex tt = ::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } + { + c10::complex x(0.5, 0.6); + c10::complex s = std::sin(x); + c10::complex ss = std::asin(s); + c10::complex c = std::cos(x); + c10::complex cc = std::acos(c); + c10::complex t = std::tan(x); + c10::complex tt = std::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } + { + c10::complex x(0.5, 0.6); + c10::complex s = ::sin(x); + c10::complex ss = ::asin(s); + c10::complex c = ::cos(x); + c10::complex cc = ::acos(c); + c10::complex t = ::tan(x); + c10::complex tt = ::atan(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } +} + +// Rev hyperbolic functions + +C10_DEFINE_TEST(TestRevHyperbolic, Rev) { + // asinh(sinh(x)) = x + // acosh(cosh(x)) = x + // atanh(tanh(x)) = x + { + c10::complex x(0.5, 0.6); + c10::complex s = std::sinh(x); + c10::complex ss = std::asinh(s); + c10::complex c = std::cosh(x); + c10::complex cc = std::acosh(c); + c10::complex t = std::tanh(x); + c10::complex tt = std::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } + { + c10::complex x(0.5, 0.6); + c10::complex s = ::sinh(x); + c10::complex ss = ::asinh(s); + c10::complex c = ::cosh(x); + c10::complex cc = ::acosh(c); + c10::complex t = ::tanh(x); + c10::complex tt = ::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } + { + c10::complex x(0.5, 0.6); + c10::complex s = std::sinh(x); + c10::complex ss = std::asinh(s); + c10::complex c = std::cosh(x); + c10::complex cc = std::acosh(c); + c10::complex t = std::tanh(x); + c10::complex tt = std::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } + { + c10::complex x(0.5, 0.6); + c10::complex s = ::sinh(x); + c10::complex ss = ::asinh(s); + c10::complex c = ::cosh(x); + c10::complex cc = ::acosh(c); + c10::complex t = ::tanh(x); + c10::complex tt = ::atanh(t); + C10_ASSERT_NEAR(x.real(), ss.real(), tol); + C10_ASSERT_NEAR(x.imag(), ss.imag(), tol); + C10_ASSERT_NEAR(x.real(), cc.real(), tol); + C10_ASSERT_NEAR(x.imag(), cc.imag(), tol); + C10_ASSERT_NEAR(x.real(), tt.real(), tol); + C10_ASSERT_NEAR(x.imag(), tt.imag(), tol); + } +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/complex_test_common.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/complex_test_common.h new file mode 100644 index 0000000000000000000000000000000000000000..94586ba1293ac4c922d6638817ce7a92b14d83b5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/test/util/complex_test_common.h @@ -0,0 +1,663 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include +#include +#include +#include +#include +#include + +#if (defined(__CUDACC__) || defined(__HIPCC__)) +#define MAYBE_GLOBAL __global__ +#else +#define MAYBE_GLOBAL +#endif + +#define PI 3.141592653589793238463 + +namespace memory { + +MAYBE_GLOBAL void test_size() { + static_assert(sizeof(c10::complex) == 2 * sizeof(float), ""); + static_assert(sizeof(c10::complex) == 2 * sizeof(double), ""); +} + +MAYBE_GLOBAL void test_align() { + static_assert(alignof(c10::complex) == 2 * sizeof(float), ""); + static_assert(alignof(c10::complex) == 2 * sizeof(double), ""); +} + +MAYBE_GLOBAL void test_pod() { + static_assert(std::is_standard_layout>::value, ""); + static_assert(std::is_standard_layout>::value, ""); +} + +TEST(TestMemory, ReinterpretCast) { + { + std::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), float(1)); + ASSERT_EQ(zz.imag(), float(2)); + } + + { + c10::complex z(3, 4); + std::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), float(3)); + ASSERT_EQ(zz.imag(), float(4)); + } + + { + std::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(1)); + ASSERT_EQ(zz.imag(), double(2)); + } + + { + c10::complex z(3, 4); + std::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(3)); + ASSERT_EQ(zz.imag(), double(4)); + } +} + +#if defined(__CUDACC__) || defined(__HIPCC__) +TEST(TestMemory, ThrustReinterpretCast) { + { + thrust::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), float(1)); + ASSERT_EQ(zz.imag(), float(2)); + } + + { + c10::complex z(3, 4); + thrust::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), float(3)); + ASSERT_EQ(zz.imag(), float(4)); + } + + { + thrust::complex z(1, 2); + c10::complex zz = *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(1)); + ASSERT_EQ(zz.imag(), double(2)); + } + + { + c10::complex z(3, 4); + thrust::complex zz = + *reinterpret_cast*>(&z); + ASSERT_EQ(zz.real(), double(3)); + ASSERT_EQ(zz.imag(), double(4)); + } +} +#endif + +} // namespace memory + +namespace constructors { + +template +C10_HOST_DEVICE void test_construct_from_scalar() { + constexpr scalar_t num1 = scalar_t(1.23); + constexpr scalar_t num2 = scalar_t(4.56); + constexpr scalar_t zero = scalar_t(); + static_assert(c10::complex(num1, num2).real() == num1, ""); + static_assert(c10::complex(num1, num2).imag() == num2, ""); + static_assert(c10::complex(num1).real() == num1, ""); + static_assert(c10::complex(num1).imag() == zero, ""); + static_assert(c10::complex().real() == zero, ""); + static_assert(c10::complex().imag() == zero, ""); +} + +template +C10_HOST_DEVICE void test_construct_from_other() { + constexpr other_t num1 = other_t(1.23); + constexpr other_t num2 = other_t(4.56); + constexpr scalar_t num3 = scalar_t(num1); + constexpr scalar_t num4 = scalar_t(num2); + static_assert( + c10::complex(c10::complex(num1, num2)).real() == num3, + ""); + static_assert( + c10::complex(c10::complex(num1, num2)).imag() == num4, + ""); +} + +MAYBE_GLOBAL void test_convert_constructors() { + test_construct_from_scalar(); + test_construct_from_scalar(); + + static_assert( + std::is_convertible, c10::complex>::value, ""); + static_assert( + !std::is_convertible, c10::complex>::value, + ""); + static_assert( + std::is_convertible, c10::complex>::value, + ""); + static_assert( + std::is_convertible, c10::complex>::value, + ""); + + static_assert( + std::is_constructible, c10::complex>::value, + ""); + static_assert( + std::is_constructible, c10::complex>::value, + ""); + static_assert( + std::is_constructible, c10::complex>::value, + ""); + static_assert( + std::is_constructible, c10::complex>::value, + ""); + + test_construct_from_other(); + test_construct_from_other(); + test_construct_from_other(); + test_construct_from_other(); +} + +template +C10_HOST_DEVICE void test_construct_from_std() { + constexpr scalar_t num1 = scalar_t(1.23); + constexpr scalar_t num2 = scalar_t(4.56); + static_assert( + c10::complex(std::complex(num1, num2)).real() == num1, + ""); + static_assert( + c10::complex(std::complex(num1, num2)).imag() == num2, + ""); +} + +MAYBE_GLOBAL void test_std_conversion() { + test_construct_from_std(); + test_construct_from_std(); +} + +#if defined(__CUDACC__) || defined(__HIPCC__) +template +void test_construct_from_thrust() { + constexpr scalar_t num1 = scalar_t(1.23); + constexpr scalar_t num2 = scalar_t(4.56); + ASSERT_EQ( + c10::complex(thrust::complex(num1, num2)).real(), + num1); + ASSERT_EQ( + c10::complex(thrust::complex(num1, num2)).imag(), + num2); +} + +TEST(TestConstructors, FromThrust) { + test_construct_from_thrust(); + test_construct_from_thrust(); +} +#endif + +TEST(TestConstructors, UnorderedMap) { + std::unordered_map< + c10::complex, + c10::complex, + c10::hash>> + m; + auto key1 = c10::complex(2.5, 3); + auto key2 = c10::complex(2, 0); + auto val1 = c10::complex(2, -3.2); + auto val2 = c10::complex(0, -3); + m[key1] = val1; + m[key2] = val2; + ASSERT_EQ(m[key1], val1); + ASSERT_EQ(m[key2], val2); +} + +} // namespace constructors + +namespace assignment { + +template +constexpr c10::complex one() { + c10::complex result(3, 4); + result = scalar_t(1); + return result; +} + +MAYBE_GLOBAL void test_assign_real() { + static_assert(one().real() == float(1), ""); + static_assert(one().imag() == float(), ""); + static_assert(one().real() == double(1), ""); + static_assert(one().imag() == double(), ""); +} + +constexpr std::tuple, c10::complex> one_two() { + constexpr c10::complex src(1, 2); + c10::complex ret0; + c10::complex ret1; + ret0 = ret1 = src; + return std::make_tuple(ret0, ret1); +} + +MAYBE_GLOBAL void test_assign_other() { + constexpr auto tup = one_two(); + static_assert(std::get>(tup).real() == double(1), ""); + static_assert(std::get>(tup).imag() == double(2), ""); + static_assert(std::get>(tup).real() == float(1), ""); + static_assert(std::get>(tup).imag() == float(2), ""); +} + +constexpr std::tuple, c10::complex> one_two_std() { + constexpr std::complex src(1, 1); + c10::complex ret0; + c10::complex ret1; + ret0 = ret1 = src; + return std::make_tuple(ret0, ret1); +} + +MAYBE_GLOBAL void test_assign_std() { + constexpr auto tup = one_two(); + static_assert(std::get>(tup).real() == double(1), ""); + static_assert(std::get>(tup).imag() == double(2), ""); + static_assert(std::get>(tup).real() == float(1), ""); + static_assert(std::get>(tup).imag() == float(2), ""); +} + +#if defined(__CUDACC__) || defined(__HIPCC__) +C10_HOST_DEVICE std::tuple, c10::complex> +one_two_thrust() { + thrust::complex src(1, 2); + c10::complex ret0; + c10::complex ret1; + ret0 = ret1 = src; + return std::make_tuple(ret0, ret1); +} + +TEST(TestAssignment, FromThrust) { + auto tup = one_two_thrust(); + ASSERT_EQ(std::get>(tup).real(), double(1)); + ASSERT_EQ(std::get>(tup).imag(), double(2)); + ASSERT_EQ(std::get>(tup).real(), float(1)); + ASSERT_EQ(std::get>(tup).imag(), float(2)); +} +#endif + +} // namespace assignment + +namespace literals { + +MAYBE_GLOBAL void test_complex_literals() { + using namespace c10::complex_literals; + static_assert(std::is_same>::value, ""); + static_assert((0.5_if).real() == float(), ""); + static_assert((0.5_if).imag() == float(0.5), ""); + static_assert( + std::is_same>::value, ""); + static_assert((0.5_id).real() == float(), ""); + static_assert((0.5_id).imag() == float(0.5), ""); + + static_assert(std::is_same>::value, ""); + static_assert((1_if).real() == float(), ""); + static_assert((1_if).imag() == float(1), ""); + static_assert(std::is_same>::value, ""); + static_assert((1_id).real() == double(), ""); + static_assert((1_id).imag() == double(1), ""); +} + +} // namespace literals + +namespace real_imag { + +template +constexpr c10::complex zero_one() { + c10::complex result; + result.imag(scalar_t(1)); + return result; +} + +template +constexpr c10::complex one_zero() { + c10::complex result; + result.real(scalar_t(1)); + return result; +} + +MAYBE_GLOBAL void test_real_imag_modify() { + static_assert(zero_one().real() == float(0), ""); + static_assert(zero_one().imag() == float(1), ""); + static_assert(zero_one().real() == double(0), ""); + static_assert(zero_one().imag() == double(1), ""); + + static_assert(one_zero().real() == float(1), ""); + static_assert(one_zero().imag() == float(0), ""); + static_assert(one_zero().real() == double(1), ""); + static_assert(one_zero().imag() == double(0), ""); +} + +} // namespace real_imag + +namespace arithmetic_assign { + +template +constexpr c10::complex p(scalar_t value) { + c10::complex result(scalar_t(2), scalar_t(2)); + result += value; + return result; +} + +template +constexpr c10::complex m(scalar_t value) { + c10::complex result(scalar_t(2), scalar_t(2)); + result -= value; + return result; +} + +template +constexpr c10::complex t(scalar_t value) { + c10::complex result(scalar_t(2), scalar_t(2)); + result *= value; + return result; +} + +template +constexpr c10::complex d(scalar_t value) { + c10::complex result(scalar_t(2), scalar_t(2)); + result /= value; + return result; +} + +template +C10_HOST_DEVICE void test_arithmetic_assign_scalar() { + constexpr c10::complex x = p(scalar_t(1)); + static_assert(x.real() == scalar_t(3), ""); + static_assert(x.imag() == scalar_t(2), ""); + constexpr c10::complex y = m(scalar_t(1)); + static_assert(y.real() == scalar_t(1), ""); + static_assert(y.imag() == scalar_t(2), ""); + constexpr c10::complex z = t(scalar_t(2)); + static_assert(z.real() == scalar_t(4), ""); + static_assert(z.imag() == scalar_t(4), ""); + constexpr c10::complex t = d(scalar_t(2)); + static_assert(t.real() == scalar_t(1), ""); + static_assert(t.imag() == scalar_t(1), ""); +} + +template +constexpr c10::complex p( + scalar_t real, + scalar_t imag, + c10::complex rhs) { + c10::complex result(real, imag); + result += rhs; + return result; +} + +template +constexpr c10::complex m( + scalar_t real, + scalar_t imag, + c10::complex rhs) { + c10::complex result(real, imag); + result -= rhs; + return result; +} + +template +constexpr c10::complex t( + scalar_t real, + scalar_t imag, + c10::complex rhs) { + c10::complex result(real, imag); + result *= rhs; + return result; +} + +template +constexpr c10::complex d( + scalar_t real, + scalar_t imag, + c10::complex rhs) { + c10::complex result(real, imag); + result /= rhs; + return result; +} + +template +C10_HOST_DEVICE void test_arithmetic_assign_complex() { + using namespace c10::complex_literals; + constexpr c10::complex x2 = p(scalar_t(2), scalar_t(2), 1.0_if); + static_assert(x2.real() == scalar_t(2), ""); + static_assert(x2.imag() == scalar_t(3), ""); + constexpr c10::complex x3 = p(scalar_t(2), scalar_t(2), 1.0_id); + static_assert(x3.real() == scalar_t(2), ""); + + // this test is skipped due to a bug in constexpr evaluation + // in nvcc. This bug has already been fixed since CUDA 11.2 +#if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020) + static_assert(x3.imag() == scalar_t(3), ""); +#endif + + constexpr c10::complex y2 = m(scalar_t(2), scalar_t(2), 1.0_if); + static_assert(y2.real() == scalar_t(2), ""); + static_assert(y2.imag() == scalar_t(1), ""); + constexpr c10::complex y3 = m(scalar_t(2), scalar_t(2), 1.0_id); + static_assert(y3.real() == scalar_t(2), ""); + + // this test is skipped due to a bug in constexpr evaluation + // in nvcc. This bug has already been fixed since CUDA 11.2 +#if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020) + static_assert(y3.imag() == scalar_t(1), ""); +#endif + + constexpr c10::complex z2 = t(scalar_t(1), scalar_t(-2), 1.0_if); + static_assert(z2.real() == scalar_t(2), ""); + static_assert(z2.imag() == scalar_t(1), ""); + constexpr c10::complex z3 = t(scalar_t(1), scalar_t(-2), 1.0_id); + static_assert(z3.real() == scalar_t(2), ""); + static_assert(z3.imag() == scalar_t(1), ""); + + constexpr c10::complex t2 = d(scalar_t(-1), scalar_t(2), 1.0_if); + static_assert(t2.real() == scalar_t(2), ""); + static_assert(t2.imag() == scalar_t(1), ""); + constexpr c10::complex t3 = d(scalar_t(-1), scalar_t(2), 1.0_id); + static_assert(t3.real() == scalar_t(2), ""); + static_assert(t3.imag() == scalar_t(1), ""); +} + +MAYBE_GLOBAL void test_arithmetic_assign() { + test_arithmetic_assign_scalar(); + test_arithmetic_assign_scalar(); + test_arithmetic_assign_complex(); + test_arithmetic_assign_complex(); +} + +} // namespace arithmetic_assign + +namespace arithmetic { + +template +C10_HOST_DEVICE void test_arithmetic_() { + static_assert( + c10::complex(1, 2) == +c10::complex(1, 2), ""); + static_assert( + c10::complex(-1, -2) == -c10::complex(1, 2), ""); + + static_assert( + c10::complex(1, 2) + c10::complex(3, 4) == + c10::complex(4, 6), + ""); + static_assert( + c10::complex(1, 2) + scalar_t(3) == + c10::complex(4, 2), + ""); + static_assert( + scalar_t(3) + c10::complex(1, 2) == + c10::complex(4, 2), + ""); + + static_assert( + c10::complex(1, 2) - c10::complex(3, 4) == + c10::complex(-2, -2), + ""); + static_assert( + c10::complex(1, 2) - scalar_t(3) == + c10::complex(-2, 2), + ""); + static_assert( + scalar_t(3) - c10::complex(1, 2) == + c10::complex(2, -2), + ""); + + static_assert( + c10::complex(1, 2) * c10::complex(3, 4) == + c10::complex(-5, 10), + ""); + static_assert( + c10::complex(1, 2) * scalar_t(3) == + c10::complex(3, 6), + ""); + static_assert( + scalar_t(3) * c10::complex(1, 2) == + c10::complex(3, 6), + ""); + + static_assert( + c10::complex(-5, 10) / c10::complex(3, 4) == + c10::complex(1, 2), + ""); + static_assert( + c10::complex(5, 10) / scalar_t(5) == + c10::complex(1, 2), + ""); + static_assert( + scalar_t(25) / c10::complex(3, 4) == + c10::complex(3, -4), + ""); +} + +MAYBE_GLOBAL void test_arithmetic() { + test_arithmetic_(); + test_arithmetic_(); +} + +template +void test_binary_ops_for_int_type_(T real, T img, int_t num) { + c10::complex c(real, img); + ASSERT_EQ(c + num, c10::complex(real + num, img)); + ASSERT_EQ(num + c, c10::complex(num + real, img)); + ASSERT_EQ(c - num, c10::complex(real - num, img)); + ASSERT_EQ(num - c, c10::complex(num - real, -img)); + ASSERT_EQ(c * num, c10::complex(real * num, img * num)); + ASSERT_EQ(num * c, c10::complex(num * real, num * img)); + ASSERT_EQ(c / num, c10::complex(real / num, img / num)); + ASSERT_EQ( + num / c, + c10::complex(num * real / std::norm(c), -num * img / std::norm(c))); +} + +template +void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) { + test_binary_ops_for_int_type_(real, img, i); + test_binary_ops_for_int_type_(real, img, i); + test_binary_ops_for_int_type_(real, img, i); + test_binary_ops_for_int_type_(real, img, i); +} + +TEST(TestArithmeticIntScalar, All) { + test_binary_ops_for_all_int_types_(1.0, 0.1, 1); + test_binary_ops_for_all_int_types_(-1.3, -0.2, -2); +} + +} // namespace arithmetic + +namespace equality { + +template +C10_HOST_DEVICE void test_equality_() { + static_assert( + c10::complex(1, 2) == c10::complex(1, 2), ""); + static_assert(c10::complex(1, 0) == scalar_t(1), ""); + static_assert(scalar_t(1) == c10::complex(1, 0), ""); + static_assert( + c10::complex(1, 2) != c10::complex(3, 4), ""); + static_assert(c10::complex(1, 2) != scalar_t(1), ""); + static_assert(scalar_t(1) != c10::complex(1, 2), ""); +} + +MAYBE_GLOBAL void test_equality() { + test_equality_(); + test_equality_(); +} + +} // namespace equality + +namespace io { + +template +void test_io_() { + std::stringstream ss; + c10::complex a(1, 2); + ss << a; + ASSERT_EQ(ss.str(), "(1,2)"); + ss.str("(3,4)"); + ss >> a; + ASSERT_TRUE(a == c10::complex(3, 4)); +} + +TEST(TestIO, All) { + test_io_(); + test_io_(); +} + +} // namespace io + +namespace test_std { + +template +C10_HOST_DEVICE void test_callable_() { + static_assert(std::real(c10::complex(1, 2)) == scalar_t(1), ""); + static_assert(std::imag(c10::complex(1, 2)) == scalar_t(2), ""); + std::abs(c10::complex(1, 2)); + std::arg(c10::complex(1, 2)); + static_assert(std::norm(c10::complex(3, 4)) == scalar_t(25), ""); + static_assert( + std::conj(c10::complex(3, 4)) == c10::complex(3, -4), + ""); + c10::polar(float(1), float(PI / 2)); + c10::polar(double(1), double(PI / 2)); +} + +MAYBE_GLOBAL void test_callable() { + test_callable_(); + test_callable_(); +} + +template +void test_values_() { + ASSERT_EQ(std::abs(c10::complex(3, 4)), scalar_t(5)); + ASSERT_LT(std::abs(std::arg(c10::complex(0, 1)) - PI / 2), 1e-6); + ASSERT_LT( + std::abs( + c10::polar(scalar_t(1), scalar_t(PI / 2)) - + c10::complex(0, 1)), + 1e-6); +} + +TEST(TestStd, BasicFunctions) { + test_values_(); + test_values_(); + // CSQRT edge cases: checks for overflows which are likely to occur + // if square root is computed using polar form + ASSERT_LT( + std::abs(std::sqrt(c10::complex(-1e20, -4988429.2)).real()), 3e-4); + ASSERT_LT( + std::abs(std::sqrt(c10::complex(-1e60, -4988429.2)).real()), + 3e-4); +} + +} // namespace test_std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/AbortHandler.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/AbortHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..f7bcaaa28af3871f95280a9bd764aea260405ca1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/AbortHandler.h @@ -0,0 +1,88 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +class AbortHandlerHelper { + public: + static AbortHandlerHelper& getInstance() { +#ifdef _WIN32 + thread_local +#endif // _WIN32 + static AbortHandlerHelper instance; + return instance; + } + + void set(std::terminate_handler handler) { + std::lock_guard lk(mutex); + if (!inited) { + prev = std::set_terminate(handler); + curr = std::get_terminate(); + inited = true; + } + } + + std::terminate_handler getPrev() const { + return prev; + } + + private: + std::terminate_handler prev = nullptr; + std::terminate_handler curr = nullptr; + bool inited = false; + std::mutex mutex; + AbortHandlerHelper() = default; + ~AbortHandlerHelper() { + // Only restore the handler if we are the current one + if (inited && curr == std::get_terminate()) { + std::set_terminate(prev); + } + } + + public: + AbortHandlerHelper(AbortHandlerHelper const&) = delete; + void operator=(AbortHandlerHelper const&) = delete; + AbortHandlerHelper(AbortHandlerHelper&&) = delete; + void operator=(AbortHandlerHelper&&) = delete; +}; + +namespace detail { +C10_ALWAYS_INLINE void terminate_handler() { + std::cout << "Unhandled exception caught in c10/util/AbortHandler.h" << '\n'; + auto backtrace = get_backtrace(); + std::cout << backtrace << '\n' << std::flush; + auto prev_handler = AbortHandlerHelper::getInstance().getPrev(); + if (prev_handler) { + prev_handler(); + } else { + std::abort(); + } +} +} // namespace detail + +C10_ALWAYS_INLINE void set_terminate_handler() { + bool use_custom_terminate = false; + // On Windows it is enabled by default based on + // https://github.com/pytorch/pytorch/pull/50320#issuecomment-763147062 +#ifdef _WIN32 + use_custom_terminate = true; +#endif // _WIN32 + auto result = c10::utils::check_env("TORCH_CUSTOM_TERMINATE"); + if (result != std::nullopt) { + use_custom_terminate = result.value(); + } + if (use_custom_terminate) { + AbortHandlerHelper::getInstance().set(detail::terminate_handler); + } +} +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/AlignOf.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/AlignOf.h new file mode 100644 index 0000000000000000000000000000000000000000..ce9fe90961700f2a1dd3f9c25e120eaa9609fc03 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/AlignOf.h @@ -0,0 +1,181 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===--- AlignOf.h - Portable calculation of type alignment -----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the AlignedCharArray and AlignedCharArrayUnion classes. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::AlignOf +// replaced LLVM_ALIGNAS with alignas + +#pragma once + +#include + +namespace c10 { + +/// \struct AlignedCharArray +/// \brief Helper for building an aligned character array type. +/// +/// This template is used to explicitly build up a collection of aligned +/// character array types. We have to build these up using a macro and explicit +/// specialization to cope with MSVC (at least till 2015) where only an +/// integer literal can be used to specify an alignment constraint. Once built +/// up here, we can then begin to indirect between these using normal C++ +/// template parameters. + +// MSVC requires special handling here. +#ifndef _MSC_VER + +template +struct AlignedCharArray { + // NOLINTNEXTLINE(*c-arrays) + alignas(Alignment) char buffer[Size]; +}; + +#else // _MSC_VER + +/// \brief Create a type with an aligned char buffer. +template +struct AlignedCharArray; + +// We provide special variations of this template for the most common +// alignments because __declspec(align(...)) doesn't actually work when it is +// a member of a by-value function argument in MSVC, even if the alignment +// request is something reasonably like 8-byte or 16-byte. Note that we can't +// even include the declspec with the union that forces the alignment because +// MSVC warns on the existence of the declspec despite the union member forcing +// proper alignment. + +template +struct AlignedCharArray<1, Size> { + union { + char aligned; + char buffer[Size]; + }; +}; + +template +struct AlignedCharArray<2, Size> { + union { + short aligned; + char buffer[Size]; + }; +}; + +template +struct AlignedCharArray<4, Size> { + union { + int aligned; + char buffer[Size]; + }; +}; + +template +struct AlignedCharArray<8, Size> { + union { + double aligned; + char buffer[Size]; + }; +}; + +// The rest of these are provided with a __declspec(align(...)) and we simply +// can't pass them by-value as function arguments on MSVC. + +#define AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(x) \ + template \ + struct AlignedCharArray { \ + __declspec(align(x)) char buffer[Size]; \ + }; + +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(16) +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(32) +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(64) +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(128) + +#undef AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT + +#endif // _MSC_VER + +namespace detail { +template < + typename T1, + typename T2 = char, + typename T3 = char, + typename T4 = char, + typename T5 = char, + typename T6 = char, + typename T7 = char, + typename T8 = char, + typename T9 = char, + typename T10 = char> +class AlignerImpl { + T1 t1; + T2 t2; + T3 t3; + T4 t4; + T5 t5; + T6 t6; + T7 t7; + T8 t8; + T9 t9; + T10 t10; + + public: + AlignerImpl() = delete; +}; + +template < + typename T1, + typename T2 = char, + typename T3 = char, + typename T4 = char, + typename T5 = char, + typename T6 = char, + typename T7 = char, + typename T8 = char, + typename T9 = char, + typename T10 = char> +union SizerImpl { + // NOLINTNEXTLINE(*c-arrays) + char arr1[sizeof(T1)], arr2[sizeof(T2)], arr3[sizeof(T3)], arr4[sizeof(T4)], + arr5[sizeof(T5)], arr6[sizeof(T6)], arr7[sizeof(T7)], arr8[sizeof(T8)], + arr9[sizeof(T9)], arr10[sizeof(T10)]; +}; +} // end namespace detail + +/// \brief This union template exposes a suitably aligned and sized character +/// array member which can hold elements of any of up to ten types. +/// +/// These types may be arrays, structs, or any other types. The goal is to +/// expose a char array buffer member which can be used as suitable storage for +/// a placement new of any of these types. Support for more than ten types can +/// be added at the cost of more boilerplate. +template < + typename T1, + typename T2 = char, + typename T3 = char, + typename T4 = char, + typename T5 = char, + typename T6 = char, + typename T7 = char, + typename T8 = char, + typename T9 = char, + typename T10 = char> +struct AlignedCharArrayUnion + : AlignedCharArray< + alignof(detail::AlignerImpl), + sizeof(::c10::detail:: + SizerImpl)> {}; +} // end namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ApproximateClock.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ApproximateClock.h new file mode 100644 index 0000000000000000000000000000000000000000..7410fc4e829fa44aadb22f61e85e1f05f9a81134 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ApproximateClock.h @@ -0,0 +1,120 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copyright 2023-present Facebook. All Rights Reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(C10_IOS) && defined(C10_MOBILE) +#include // for gettimeofday() +#endif + +#if defined(__i386__) || defined(__x86_64__) || defined(__amd64__) +#define C10_RDTSC +#if defined(_MSC_VER) +#include +#elif defined(__CUDACC__) || defined(__HIPCC__) +#undef C10_RDTSC +#elif defined(__clang__) +// `__rdtsc` is available by default. +// NB: This has to be first, because Clang will also define `__GNUC__` +#elif defined(__GNUC__) +#include +#else +#undef C10_RDTSC +#endif +#endif + +namespace c10 { + +using time_t = int64_t; +using steady_clock_t = std::conditional_t< + std::chrono::high_resolution_clock::is_steady, + std::chrono::high_resolution_clock, + std::chrono::steady_clock>; + +inline time_t getTimeSinceEpoch() { + auto now = std::chrono::system_clock::now().time_since_epoch(); + return std::chrono::duration_cast(now).count(); +} + +inline time_t getTime(bool allow_monotonic = false) { +#if defined(C10_IOS) && defined(C10_MOBILE) + // clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS + // can't rely on CLOCK_REALTIME, as it is defined no matter if clock_gettime + // is implemented or not + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000000000 + + static_cast(now.tv_usec) * 1000; +#elif defined(_WIN32) || defined(__MACH__) + return std::chrono::duration_cast( + steady_clock_t::now().time_since_epoch()) + .count(); +#else + // clock_gettime is *much* faster than std::chrono implementation on Linux + struct timespec t{}; + auto mode = CLOCK_REALTIME; + if (allow_monotonic) { + mode = CLOCK_MONOTONIC; + } + clock_gettime(mode, &t); + return static_cast(t.tv_sec) * 1000000000 + + static_cast(t.tv_nsec); +#endif +} + +// We often do not need to capture true wall times. If a fast mechanism such +// as TSC is available we can use that instead and convert back to epoch time +// during post processing. This greatly reduce the clock's contribution to +// profiling. +// http://btorpey.github.io/blog/2014/02/18/clock-sources-in-linux/ +// https://quick-bench.com/q/r8opkkGZSJMu9wM_XTbDouq-0Io +// TODO: We should use +// `https://github.com/google/benchmark/blob/main/src/cycleclock.h` +inline auto getApproximateTime() { +#if defined(C10_RDTSC) + return static_cast(__rdtsc()); +#else + return getTime(); +#endif +} + +using approx_time_t = decltype(getApproximateTime()); +static_assert( + std::is_same_v || + std::is_same_v, + "Expected either int64_t (`getTime`) or uint64_t (some TSC reads)."); + +// Convert `getCount` results to Nanoseconds since unix epoch. +class C10_API ApproximateClockToUnixTimeConverter final { + public: + ApproximateClockToUnixTimeConverter(); + std::function makeConverter(); + + struct UnixAndApproximateTimePair { + time_t t_; + approx_time_t approx_t_; + }; + static UnixAndApproximateTimePair measurePair(); + + private: + static constexpr size_t replicates = 1001; + using time_pairs = std::array; + time_pairs measurePairs(); + + time_pairs start_times_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Array.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Array.h new file mode 100644 index 0000000000000000000000000000000000000000..5cb2d8dff74253bf9c54d53b3aa532d91bee89a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Array.h @@ -0,0 +1,23 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +// This helper function creates a constexpr std::array +// From a compile time list of values, without requiring you to explicitly +// write out the length. +// +// See also https://stackoverflow.com/a/26351760/23845 +template +inline constexpr auto array_of(T&&... t) -> std::array { + return {{std::forward(t)...}}; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ArrayRef.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..9da524e96ce718b7782e1584a795c919af0ecd78 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ArrayRef.h @@ -0,0 +1,326 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::ArrayRef. +// removed llvm-specific functionality +// removed some implicit const -> non-const conversions that rely on +// complicated std::enable_if meta-programming +// removed a bunch of slice variants for simplicity... + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +/// ArrayRef - Represent a constant reference to an array (0 or more elements +/// consecutively in memory), i.e. a start pointer and a length. It allows +/// various APIs to take consecutive elements easily and conveniently. +/// +/// This class does not own the underlying data, it is expected to be used in +/// situations where the data resides in some other buffer, whose lifetime +/// extends past that of the ArrayRef. For this reason, it is not in general +/// safe to store an ArrayRef. +/// +/// This is intended to be trivially copyable, so it should be passed by +/// value. +/// +/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct +/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of +/// the underlying constexpr calls, we rely on apparent-type dispatch for +/// inheritance. This should be fine because their memory format is the same, +/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods. +/// However, you should prefer to use ArrayRef when possible, because its use +/// of TORCH_CHECK will lead to better user-facing error messages. +template +// ArrayRef cannot be derived from. Normally, we would use `final` +// specifier to force this constraint at compile time. However, Intel +// compiler does not recognize ArrayRef as a class template (which is +// required in the definition of at::TensorAccessor, for instance) +// when `final` specifier is used. So, we cannot define ArrayRef as +// final because of the Intel compiler issue. +class ArrayRef : public HeaderOnlyArrayRef { + public: + /// @name Constructors, all inherited from HeaderOnlyArrayRef except for + /// SmallVector. As inherited constructors won't work with class template + /// argument deduction (CTAD) until C++23, we add deduction guides after + /// the class definition to enable CTAD. + /// @{ + + using HeaderOnlyArrayRef::HeaderOnlyArrayRef; + + /// Construct an ArrayRef from a SmallVector. This is templated in order to + /// avoid instantiating SmallVectorTemplateCommon whenever we + /// copy-construct an ArrayRef. + /// NOTE: this is the only constructor that is not inherited from + /// HeaderOnlyArrayRef. + template + /* implicit */ ArrayRef(const SmallVectorTemplateCommon& Vec) + : HeaderOnlyArrayRef(Vec.data(), Vec.size()) {} + + /// @} + /// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef + /// @{ + + /// front - Get the first element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK + constexpr const T& front() const { + TORCH_CHECK( + !this->empty(), "ArrayRef: attempted to access front() of empty list"); + return this->Data[0]; + } + + /// back - Get the last element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK + constexpr const T& back() const { + TORCH_CHECK( + !this->empty(), "ArrayRef: attempted to access back() of empty list"); + return this->Data[this->Length - 1]; + } + + /// slice(n, m) - Take M elements of the array starting at element N + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK + constexpr ArrayRef slice(size_t N, size_t M) const { + TORCH_CHECK( + N + M <= this->size(), + "ArrayRef: invalid slice, N = ", + N, + "; M = ", + M, + "; size = ", + this->size()); + return ArrayRef(this->data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK + constexpr ArrayRef slice(size_t N) const { + TORCH_CHECK( + N <= this->size(), + "ArrayRef: invalid slice, N = ", + N, + "; size = ", + this->size()); + return slice(N, this->size() - N); // should this slice be this->slice? + } + + /// @} + /// @name Operator Overloads + /// @{ + + /// Vector compatibility + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK + constexpr const T& at(size_t Index) const { + TORCH_CHECK( + Index < this->Length, + "ArrayRef: invalid index Index = ", + Index, + "; Length = ", + this->Length); + return this->Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, ArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, ArrayRef>& operator=( + std::initializer_list) = delete; + + /// @} +}; + +/// Deduction guides for ArrayRef to support CTAD with inherited constructors +/// These mirror the constructors inherited from HeaderOnlyArrayRef +/// @{ + +// Single element constructor +template +ArrayRef(const T&) -> ArrayRef; + +// Pointer and length constructor +template +ArrayRef(const T*, size_t) -> ArrayRef; + +// Range constructor (begin, end) +template +ArrayRef(const T*, const T*) -> ArrayRef; + +// Generic container constructor (anything with .data() and .size()) +template +ArrayRef(const Container&) -> ArrayRef< + std::remove_pointer_t().data())>>; + +// std::vector constructor +template +ArrayRef(const std::vector&) -> ArrayRef; + +// std::array constructor +template +ArrayRef(const std::array&) -> ArrayRef; + +// C array constructor +template +ArrayRef(const T (&)[N]) -> ArrayRef; + +// std::initializer_list constructor +template +ArrayRef(const std::initializer_list&) -> ArrayRef; + +/// @} + +template +std::ostream& operator<<(std::ostream& out, ArrayRef list) { + int i = 0; + out << '['; + for (const auto& e : list) { + if (i++ > 0) + out << ", "; + out << e; + } + out << ']'; + return out; +} + +/// @name ArrayRef Convenience constructors +/// @{ + +/// Construct an ArrayRef from a single element. +template +ArrayRef makeArrayRef(const T& OneElt) { + return OneElt; +} + +/// Construct an ArrayRef from a pointer and length. +template +ArrayRef makeArrayRef(const T* data, size_t length) { + return ArrayRef(data, length); +} + +/// Construct an ArrayRef from a range. +template +ArrayRef makeArrayRef(const T* begin, const T* end) { + return ArrayRef(begin, end); +} + +/// Construct an ArrayRef from a SmallVector. +template +ArrayRef makeArrayRef(const SmallVectorImpl& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a SmallVector. +template +ArrayRef makeArrayRef(const SmallVector& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::vector. +template +ArrayRef makeArrayRef(const std::vector& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::array. +template +ArrayRef makeArrayRef(const std::array& Arr) { + return Arr; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) (const) +template +ArrayRef makeArrayRef(const ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) +template +ArrayRef& makeArrayRef(ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a C array. +template +// NOLINTNEXTLINE(*c-arrays*) +ArrayRef makeArrayRef(const T (&Arr)[N]) { + return ArrayRef(Arr); +} + +// WARNING: Template instantiation will NOT be willing to do an implicit +// conversions to get you to an c10::ArrayRef, which is why we need so +// many overloads. + +template +bool operator==(c10::ArrayRef a1, c10::ArrayRef a2) { + return a1.equals(a2); +} + +template +bool operator!=(c10::ArrayRef a1, c10::ArrayRef a2) { + return !a1.equals(a2); +} + +template +bool operator==(const std::vector& a1, c10::ArrayRef a2) { + return c10::ArrayRef(a1).equals(a2); +} + +template +bool operator!=(const std::vector& a1, c10::ArrayRef a2) { + return !c10::ArrayRef(a1).equals(a2); +} + +template +bool operator==(c10::ArrayRef a1, const std::vector& a2) { + return a1.equals(c10::ArrayRef(a2)); +} + +template +bool operator!=(c10::ArrayRef a1, const std::vector& a2) { + return !a1.equals(c10::ArrayRef(a2)); +} + +using IntArrayRef = ArrayRef; + +using IntList [[deprecated( + "This alias is deprecated because it doesn't make ownership semantics obvious. Use IntArrayRef instead!")]] = + ArrayRef; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..90ca6b677ab3740550f4700479497fd58c35536b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-math.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-math.h new file mode 100644 index 0000000000000000000000000000000000000000..6865f84fa6af5dbd8e2fb60ff46f1bbabdead1fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-math.h @@ -0,0 +1,304 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +namespace c10 { +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; +} // namespace c10 + +namespace std { + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +using c10::is_reduced_floating_point; +using c10::is_reduced_floating_point_v; +#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T acos(T a) { + return std::acos(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T asin(T a) { + return std::asin(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T atan(T a) { + return std::atan(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T atanh(T a) { + return std::atanh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T erf(T a) { + return std::erf(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T erfc(T a) { + return std::erfc(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T exp(T a) { + return std::exp(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T expm1(T a) { + return std::expm1(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log(T a) { + return std::log(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log10(T a) { + return std::log10(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log1p(T a) { + return std::log1p(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log2(T a) { + return std::log2(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T ceil(T a) { + return std::ceil(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T cos(T a) { + return std::cos(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T floor(T a) { + return std::floor(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T nearbyint(T a) { + return std::nearbyint(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sin(T a) { + return std::sin(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T tan(T a) { + return std::tan(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sinh(T a) { + return std::sinh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T cosh(T a) { + return std::cosh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T tanh(T a) { + return std::tanh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T trunc(T a) { + return std::trunc(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T lgamma(T a) { + return std::lgamma(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sqrt(T a) { + return std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T rsqrt(T a) { + return 1.0 / std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T abs(T a) { + return std::abs(float(a)); +} +#if defined(_MSC_VER) && defined(__CUDACC__) +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), float(b)); +} +#else +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), b); +} +#endif +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, T b) { + return std::pow(float(a), float(b)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T fmod(T a, T b) { + return std::fmod(float(a), float(b)); +} + +/* + The following function is inspired from the implementation in `musl` + Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT + ---------------------------------------------------------------------- + Copyright © 2005-2020 Rich Felker, et al. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------------------------------------------------------------- + */ +template < + typename T, + typename std::enable_if_t, int> = 0> +C10_HOST_DEVICE inline T nextafter(T from, T to) { + // Reference: + // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c + using int_repr_t = uint16_t; + constexpr uint8_t bits = 16; + union { + T f; + int_repr_t i; + } ufrom = {from}, uto = {to}; + + // get a mask to get the sign bit i.e. MSB + int_repr_t sign_mask = int_repr_t{1} << (bits - 1); + + // short-circuit: if either is NaN, return NaN + if (from != from || to != to) { + return from + to; + } + + // short-circuit: if they are exactly the same. + if (ufrom.i == uto.i) { + return from; + } + + // mask the sign-bit to zero i.e. positive + // equivalent to abs(x) + int_repr_t abs_from = ufrom.i & ~sign_mask; + int_repr_t abs_to = uto.i & ~sign_mask; + if (abs_from == 0) { + // if both are zero but with different sign, + // preserve the sign of `to`. + if (abs_to == 0) { + return to; + } + // smallest subnormal with sign of `to`. + ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; + return ufrom.f; + } + + // if abs(from) > abs(to) or sign(from) != sign(to) + if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { + ufrom.i--; + } else { + ufrom.i++; + } + + return ufrom.f; +} + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..90ca6b677ab3740550f4700479497fd58c35536b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/BFloat16.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Backtrace.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Backtrace.h new file mode 100644 index 0000000000000000000000000000000000000000..0a9e8d2c27ff43ab571d3883567ef5535c3287db --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Backtrace.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_BACKTRACE_H_ +#define C10_UTIL_BACKTRACE_H_ + +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +// Symbolizing the backtrace can be expensive; pass it around as a lazy string +// so it is symbolized only if actually needed. +using Backtrace = std::shared_ptr>; + +// DEPRECATED: Prefer get_lazy_backtrace(). +C10_API std::string get_backtrace( + size_t frames_to_skip = 0, + size_t maximum_number_of_frames = 64, + bool skip_python_frames = true); + +C10_API Backtrace get_lazy_backtrace( + size_t frames_to_skip = 0, + size_t maximum_number_of_frames = 64, + bool skip_python_frames = true); + +} // namespace c10 + +#endif // C10_UTIL_BACKTRACE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Bitset.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Bitset.h new file mode 100644 index 0000000000000000000000000000000000000000..1e01d94ea590ccd96414ec760b09a48419de9de8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Bitset.h @@ -0,0 +1,123 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#if defined(_MSC_VER) +#include +#endif + +namespace c10::utils { + +/** + * This is a simple bitset class with sizeof(long long int) bits. + * You can set bits, unset bits, query bits by index, + * and query for the first set bit. + * Before using this class, please also take a look at std::bitset, + * which has more functionality and is more generic. It is probably + * a better fit for your use case. The sole reason for c10::utils::bitset + * to exist is that std::bitset misses a find_first_set() method. + */ +struct bitset final { + private: +#if defined(_MSC_VER) + // MSVCs _BitScanForward64 expects int64_t + using bitset_type = int64_t; +#else + // POSIX ffsll expects long long int + using bitset_type = long long int; +#endif + public: + static constexpr size_t NUM_BITS() { + return 8 * sizeof(bitset_type); + } + + constexpr bitset() noexcept = default; + constexpr bitset(const bitset&) noexcept = default; + constexpr bitset(bitset&&) noexcept = default; + // there is an issue for gcc 5.3.0 when define default function as constexpr + // see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754. + bitset& operator=(const bitset&) noexcept = default; + bitset& operator=(bitset&&) noexcept = default; + ~bitset() = default; + + constexpr void set(size_t index) noexcept { + bitset_ |= (static_cast(1) << index); + } + + constexpr void unset(size_t index) noexcept { + bitset_ &= ~(static_cast(1) << index); + } + + constexpr bool get(size_t index) const noexcept { + return bitset_ & (static_cast(1) << index); + } + + constexpr bool is_entirely_unset() const noexcept { + return 0 == bitset_; + } + + // Call the given functor with the index of each bit that is set + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + void for_each_set_bit(Func&& func) const { + bitset cur = *this; + size_t index = cur.find_first_set(); + while (0 != index) { + // -1 because find_first_set() is not one-indexed. + index -= 1; + func(index); + cur.unset(index); + index = cur.find_first_set(); + } + } + + private: + // Return the index of the first set bit. The returned index is one-indexed + // (i.e. if the very first bit is set, this function returns '1'), and a + // return of '0' means that there was no bit set. + size_t find_first_set() const { +#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64)) + unsigned long result; + bool has_bits_set = (0 != _BitScanForward64(&result, bitset_)); + if (!has_bits_set) { + return 0; + } + return result + 1; +#elif defined(_MSC_VER) && defined(_M_IX86) + unsigned long result; + if (static_cast(bitset_) != 0) { + bool has_bits_set = + (0 != _BitScanForward(&result, static_cast(bitset_))); + if (!has_bits_set) { + return 0; + } + return result + 1; + } else { + bool has_bits_set = + (0 != _BitScanForward(&result, static_cast(bitset_ >> 32))); + if (!has_bits_set) { + return 32; + } + return result + 33; + } +#else + return __builtin_ffsll(bitset_); +#endif + } + + friend bool operator==(bitset lhs, bitset rhs) noexcept { + return lhs.bitset_ == rhs.bitset_; + } + + bitset_type bitset_{0}; +}; + +inline bool operator!=(bitset lhs, bitset rhs) noexcept { + return !(lhs == rhs); +} + +} // namespace c10::utils + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/C++17.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/C++17.h new file mode 100644 index 0000000000000000000000000000000000000000..f9e010daa58b3e172456412c6478bfb1006b3e3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/C++17.h @@ -0,0 +1,75 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#ifndef C10_UTIL_CPP17_H_ +#define C10_UTIL_CPP17_H_ + +#include +#include +#include +#include +#include + +#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \ + __GNUC__ < 9 +#error \ + "You're trying to build PyTorch with a too old version of GCC. We need GCC 9 or later." +#endif + +#if defined(__clang__) && __clang_major__ < 9 +#error \ + "You're trying to build PyTorch with a too old version of Clang. We need Clang 9 or later." +#endif + +#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \ + (!defined(_MSC_VER) && __cplusplus < 201703L) +#error You need C++17 to compile PyTorch +#endif + +#if defined(_WIN32) && (defined(min) || defined(max)) +#error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows +#endif + +/* + * This header adds some polyfills with C++17 functionality + */ + +namespace c10 { + +namespace guts { + +#if defined(__HIP__) + +// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but +// modified) +// TODO This is an incomplete implementation of std::apply, not working for +// member functions. +namespace detail { +template +C10_HOST_DEVICE constexpr auto apply_impl( + F&& f, + Tuple&& t, + std::index_sequence) { + return std::forward(f)(std::get(std::forward(t))...); +} +} // namespace detail + +template +C10_HOST_DEVICE constexpr auto apply(F&& f, Tuple&& t) { + return detail::apply_impl( + std::forward(f), + std::forward(t), + std::make_index_sequence< + std::tuple_size>::value>{}); +} + +#endif + +} // namespace guts + +} // namespace c10 + +#endif // C10_UTIL_CPP17_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/CallOnce.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/CallOnce.h new file mode 100644 index 0000000000000000000000000000000000000000..0037755f64a8fce82ae816391559c2123e3ad1cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/CallOnce.h @@ -0,0 +1,75 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +// custom c10 call_once implementation to avoid the deadlock in std::call_once. +// The implementation here is a simplified version from folly and likely much +// much higher memory footprint. +template +inline void call_once(Flag& flag, F&& f, Args&&... args) { + if (C10_LIKELY(flag.test_once())) { + return; + } + flag.call_once_slow(std::forward(f), std::forward(args)...); +} + +class once_flag { + public: +#ifndef _WIN32 + // running into build error on MSVC. Can't seem to get a repro locally so I'm + // just avoiding constexpr + // + // C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error: + // defaulted default constructor cannot be constexpr because the + // corresponding implicitly declared default constructor would not be + // constexpr 1 error detected in the compilation of + // "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu". + constexpr +#endif + once_flag() noexcept = default; + once_flag(const once_flag&) = delete; + once_flag& operator=(const once_flag&) = delete; + once_flag(once_flag&&) = delete; + once_flag& operator=(once_flag&&) = delete; + ~once_flag() = default; + bool test_once() { + return init_.load(std::memory_order_acquire); + } + + private: + template + friend void call_once(Flag& flag, F&& f, Args&&... args); + + template + void call_once_slow(F&& f, Args&&... args) { + std::lock_guard guard(mutex_); + if (init_.load(std::memory_order_relaxed)) { + return; + } + std::invoke(std::forward(f), std::forward(args)...); + init_.store(true, std::memory_order_release); + } + + void reset_once() { + init_.store(false, std::memory_order_release); + } + + private: + std::mutex mutex_; + std::atomic init_{false}; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ConstexprCrc.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ConstexprCrc.h new file mode 100644 index 0000000000000000000000000000000000000000..56dd979ce833087e264e6e8faef8563019fa3ea5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ConstexprCrc.h @@ -0,0 +1,137 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10::util { + +namespace detail { +// NOLINTNEXTLINE(*c-arrays*) +constexpr uint64_t crc64_table[] = { + 0x0000000000000000, 0x7ad870c830358979, 0xf5b0e190606b12f2, + 0x8f689158505e9b8b, 0xc038e5739841b68f, 0xbae095bba8743ff6, + 0x358804e3f82aa47d, 0x4f50742bc81f2d04, 0xab28ecb46814fe75, + 0xd1f09c7c5821770c, 0x5e980d24087fec87, 0x24407dec384a65fe, + 0x6b1009c7f05548fa, 0x11c8790fc060c183, 0x9ea0e857903e5a08, + 0xe478989fa00bd371, 0x7d08ff3b88be6f81, 0x07d08ff3b88be6f8, + 0x88b81eabe8d57d73, 0xf2606e63d8e0f40a, 0xbd301a4810ffd90e, + 0xc7e86a8020ca5077, 0x4880fbd87094cbfc, 0x32588b1040a14285, + 0xd620138fe0aa91f4, 0xacf86347d09f188d, 0x2390f21f80c18306, + 0x594882d7b0f40a7f, 0x1618f6fc78eb277b, 0x6cc0863448deae02, + 0xe3a8176c18803589, 0x997067a428b5bcf0, 0xfa11fe77117cdf02, + 0x80c98ebf2149567b, 0x0fa11fe77117cdf0, 0x75796f2f41224489, + 0x3a291b04893d698d, 0x40f16bccb908e0f4, 0xcf99fa94e9567b7f, + 0xb5418a5cd963f206, 0x513912c379682177, 0x2be1620b495da80e, + 0xa489f35319033385, 0xde51839b2936bafc, 0x9101f7b0e12997f8, + 0xebd98778d11c1e81, 0x64b116208142850a, 0x1e6966e8b1770c73, + 0x8719014c99c2b083, 0xfdc17184a9f739fa, 0x72a9e0dcf9a9a271, + 0x08719014c99c2b08, 0x4721e43f0183060c, 0x3df994f731b68f75, + 0xb29105af61e814fe, 0xc849756751dd9d87, 0x2c31edf8f1d64ef6, + 0x56e99d30c1e3c78f, 0xd9810c6891bd5c04, 0xa3597ca0a188d57d, + 0xec09088b6997f879, 0x96d1784359a27100, 0x19b9e91b09fcea8b, + 0x636199d339c963f2, 0xdf7adabd7a6e2d6f, 0xa5a2aa754a5ba416, + 0x2aca3b2d1a053f9d, 0x50124be52a30b6e4, 0x1f423fcee22f9be0, + 0x659a4f06d21a1299, 0xeaf2de5e82448912, 0x902aae96b271006b, + 0x74523609127ad31a, 0x0e8a46c1224f5a63, 0x81e2d7997211c1e8, + 0xfb3aa75142244891, 0xb46ad37a8a3b6595, 0xceb2a3b2ba0eecec, + 0x41da32eaea507767, 0x3b024222da65fe1e, 0xa2722586f2d042ee, + 0xd8aa554ec2e5cb97, 0x57c2c41692bb501c, 0x2d1ab4dea28ed965, + 0x624ac0f56a91f461, 0x1892b03d5aa47d18, 0x97fa21650afae693, + 0xed2251ad3acf6fea, 0x095ac9329ac4bc9b, 0x7382b9faaaf135e2, + 0xfcea28a2faafae69, 0x8632586aca9a2710, 0xc9622c4102850a14, + 0xb3ba5c8932b0836d, 0x3cd2cdd162ee18e6, 0x460abd1952db919f, + 0x256b24ca6b12f26d, 0x5fb354025b277b14, 0xd0dbc55a0b79e09f, + 0xaa03b5923b4c69e6, 0xe553c1b9f35344e2, 0x9f8bb171c366cd9b, + 0x10e3202993385610, 0x6a3b50e1a30ddf69, 0x8e43c87e03060c18, + 0xf49bb8b633338561, 0x7bf329ee636d1eea, 0x012b592653589793, + 0x4e7b2d0d9b47ba97, 0x34a35dc5ab7233ee, 0xbbcbcc9dfb2ca865, + 0xc113bc55cb19211c, 0x5863dbf1e3ac9dec, 0x22bbab39d3991495, + 0xadd33a6183c78f1e, 0xd70b4aa9b3f20667, 0x985b3e827bed2b63, + 0xe2834e4a4bd8a21a, 0x6debdf121b863991, 0x1733afda2bb3b0e8, + 0xf34b37458bb86399, 0x8993478dbb8deae0, 0x06fbd6d5ebd3716b, + 0x7c23a61ddbe6f812, 0x3373d23613f9d516, 0x49aba2fe23cc5c6f, + 0xc6c333a67392c7e4, 0xbc1b436e43a74e9d, 0x95ac9329ac4bc9b5, + 0xef74e3e19c7e40cc, 0x601c72b9cc20db47, 0x1ac40271fc15523e, + 0x5594765a340a7f3a, 0x2f4c0692043ff643, 0xa02497ca54616dc8, + 0xdafce7026454e4b1, 0x3e847f9dc45f37c0, 0x445c0f55f46abeb9, + 0xcb349e0da4342532, 0xb1eceec59401ac4b, 0xfebc9aee5c1e814f, + 0x8464ea266c2b0836, 0x0b0c7b7e3c7593bd, 0x71d40bb60c401ac4, + 0xe8a46c1224f5a634, 0x927c1cda14c02f4d, 0x1d148d82449eb4c6, + 0x67ccfd4a74ab3dbf, 0x289c8961bcb410bb, 0x5244f9a98c8199c2, + 0xdd2c68f1dcdf0249, 0xa7f41839ecea8b30, 0x438c80a64ce15841, + 0x3954f06e7cd4d138, 0xb63c61362c8a4ab3, 0xcce411fe1cbfc3ca, + 0x83b465d5d4a0eece, 0xf96c151de49567b7, 0x76048445b4cbfc3c, + 0x0cdcf48d84fe7545, 0x6fbd6d5ebd3716b7, 0x15651d968d029fce, + 0x9a0d8ccedd5c0445, 0xe0d5fc06ed698d3c, 0xaf85882d2576a038, + 0xd55df8e515432941, 0x5a3569bd451db2ca, 0x20ed197575283bb3, + 0xc49581ead523e8c2, 0xbe4df122e51661bb, 0x3125607ab548fa30, + 0x4bfd10b2857d7349, 0x04ad64994d625e4d, 0x7e7514517d57d734, + 0xf11d85092d094cbf, 0x8bc5f5c11d3cc5c6, 0x12b5926535897936, + 0x686de2ad05bcf04f, 0xe70573f555e26bc4, 0x9ddd033d65d7e2bd, + 0xd28d7716adc8cfb9, 0xa85507de9dfd46c0, 0x273d9686cda3dd4b, + 0x5de5e64efd965432, 0xb99d7ed15d9d8743, 0xc3450e196da80e3a, + 0x4c2d9f413df695b1, 0x36f5ef890dc31cc8, 0x79a59ba2c5dc31cc, + 0x037deb6af5e9b8b5, 0x8c157a32a5b7233e, 0xf6cd0afa9582aa47, + 0x4ad64994d625e4da, 0x300e395ce6106da3, 0xbf66a804b64ef628, + 0xc5bed8cc867b7f51, 0x8aeeace74e645255, 0xf036dc2f7e51db2c, + 0x7f5e4d772e0f40a7, 0x05863dbf1e3ac9de, 0xe1fea520be311aaf, + 0x9b26d5e88e0493d6, 0x144e44b0de5a085d, 0x6e963478ee6f8124, + 0x21c640532670ac20, 0x5b1e309b16452559, 0xd476a1c3461bbed2, + 0xaeaed10b762e37ab, 0x37deb6af5e9b8b5b, 0x4d06c6676eae0222, + 0xc26e573f3ef099a9, 0xb8b627f70ec510d0, 0xf7e653dcc6da3dd4, + 0x8d3e2314f6efb4ad, 0x0256b24ca6b12f26, 0x788ec2849684a65f, + 0x9cf65a1b368f752e, 0xe62e2ad306bafc57, 0x6946bb8b56e467dc, + 0x139ecb4366d1eea5, 0x5ccebf68aecec3a1, 0x2616cfa09efb4ad8, + 0xa97e5ef8cea5d153, 0xd3a62e30fe90582a, 0xb0c7b7e3c7593bd8, + 0xca1fc72bf76cb2a1, 0x45775673a732292a, 0x3faf26bb9707a053, + 0x70ff52905f188d57, 0x0a2722586f2d042e, 0x854fb3003f739fa5, + 0xff97c3c80f4616dc, 0x1bef5b57af4dc5ad, 0x61372b9f9f784cd4, + 0xee5fbac7cf26d75f, 0x9487ca0fff135e26, 0xdbd7be24370c7322, + 0xa10fceec0739fa5b, 0x2e675fb4576761d0, 0x54bf2f7c6752e8a9, + 0xcdcf48d84fe75459, 0xb71738107fd2dd20, 0x387fa9482f8c46ab, + 0x42a7d9801fb9cfd2, 0x0df7adabd7a6e2d6, 0x772fdd63e7936baf, + 0xf8474c3bb7cdf024, 0x829f3cf387f8795d, 0x66e7a46c27f3aa2c, + 0x1c3fd4a417c62355, 0x935745fc4798b8de, 0xe98f353477ad31a7, + 0xa6df411fbfb21ca3, 0xdc0731d78f8795da, 0x536fa08fdfd90e51, + 0x29b7d047efec8728, +}; + +inline constexpr uint64_t crc64impl( + uint64_t accumulator, + const char* data, + size_t size) { + for (size_t i = 0; i < size; ++i) { + accumulator = + crc64_table[(accumulator ^ data[i]) & 0xFF] ^ (accumulator >> 8); + } + return accumulator; +} +} // namespace detail + +struct crc64_t final : IdWrapper { + constexpr crc64_t(uint64_t checksum) : IdWrapper(checksum) {} + constexpr uint64_t checksum() const { + return this->underlyingId(); + } +}; + +// CRC64 with Jones coefficients and an init value of 0. +inline constexpr crc64_t crc64(const char* str, size_t size) { + return crc64_t{detail::crc64impl(0, str, size)}; +} + +inline constexpr crc64_t crc64(std::string_view str) { + return crc64(str.data(), str.size()); +} +} // namespace c10::util + +// Allow usage of crc64_t in std::unordered_set +C10_DEFINE_HASH_FOR_IDWRAPPER(c10::util::crc64_t) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DeadlockDetection.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DeadlockDetection.h new file mode 100644 index 0000000000000000000000000000000000000000..5fd611a2add7563d8c6ca6fba28e704765c4ec79 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DeadlockDetection.h @@ -0,0 +1,57 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +/// This file provides some simple utilities for detecting common deadlocks in +/// PyTorch. For now, we focus exclusively on detecting Python GIL deadlocks, +/// as the GIL is a wide ranging lock that is taken out in many situations. +/// The basic strategy is before performing an operation that may block, you +/// can use TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() to assert that the GIL is +/// not held. This macro is to be used in contexts where no static dependency +/// on Python is available (we will handle indirecting a virtual call for you). +/// +/// If the GIL is held by a torchdeploy interpreter, we always report false. +/// If you are in a context where Python bindings are available, it's better +/// to directly assert on PyGILState_Check (as it avoids a vcall and also +/// works correctly with torchdeploy.) + +#define TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() \ + TORCH_INTERNAL_ASSERT( \ + !c10::impl::check_python_gil(), \ + "Holding GIL before a blocking operation! Please release the GIL before blocking, or see https://github.com/pytorch/pytorch/issues/56297 for how to release the GIL for destructors of objects") + +namespace c10::impl { + +C10_API bool check_python_gil(); + +struct C10_API PythonGILHooks { + virtual ~PythonGILHooks() = default; + // Returns true if we hold the GIL. If not linked against Python we + // always return false. + virtual bool check_python_gil() const = 0; +}; + +C10_API void SetPythonGILHooks(PythonGILHooks* factory); + +// DO NOT call this registerer from a torch deploy instance! You will clobber +// other registrations +struct C10_API PythonGILHooksRegisterer { + explicit PythonGILHooksRegisterer(PythonGILHooks* factory) { + SetPythonGILHooks(factory); + } + PythonGILHooksRegisterer(const PythonGILHooksRegisterer&) = delete; + PythonGILHooksRegisterer(PythonGILHooksRegisterer&&) = delete; + PythonGILHooksRegisterer& operator=(const PythonGILHooksRegisterer&) = delete; + PythonGILHooksRegisterer& operator=(PythonGILHooksRegisterer&&) = delete; + ~PythonGILHooksRegisterer() { + SetPythonGILHooks(nullptr); + } +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Deprecated.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Deprecated.h new file mode 100644 index 0000000000000000000000000000000000000000..ccd1ac50400d3dcdc160c42e8745bac7139c8217 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Deprecated.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DimVector.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..682b8f364a2094c0feec2b6c19a8e2e54d296ee1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DimVector.h @@ -0,0 +1,22 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +constexpr size_t kDimVectorStaticSize = C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + +/// A container for sizes or strides +using DimVector = SmallVector; +using SymDimVector = SmallVector; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DynamicCounter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DynamicCounter.h new file mode 100644 index 0000000000000000000000000000000000000000..37e0af4319435c223442cc52d2b34d8e12e2715b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/DynamicCounter.h @@ -0,0 +1,54 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include + +namespace c10::monitor { + +class C10_API DynamicCounter { + public: + using Callback = std::function; + + // Creates a dynamic counter that can be queried at any point in time by + // multiple backends. Only one counter with a given key can exist at any point + // in time. + // + // The callback is invoked every time the counter is queried. + // The callback must be thread-safe. + // The callback must not throw. + // The callback must not block. + DynamicCounter(std::string_view key, Callback getCounterCallback); + + // Unregisters the callback. + // Waits for all ongoing callback invocations to finish. + ~DynamicCounter(); + + private: + struct Guard; + std::unique_ptr guard_; +}; + +namespace detail { +class DynamicCounterBackendIf { + public: + virtual ~DynamicCounterBackendIf() = default; + + virtual void registerCounter( + std::string_view key, + DynamicCounter::Callback getCounterCallback) = 0; + // MUST wait for all ongoing callback invocations to finish + virtual void unregisterCounter(std::string_view key) = 0; +}; + +void C10_API registerDynamicCounterBackend( + std::unique_ptr /*backend*/); +} // namespace detail +} // namespace c10::monitor + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Enumerate.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Enumerate.h new file mode 100644 index 0000000000000000000000000000000000000000..441e158ccc4ab86cd7c19a25963d1da7005c82e9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Enumerate.h @@ -0,0 +1,164 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Ported from folly/container/Enumerate.h + */ + +#pragma once + +#include +#include + +#ifdef _WIN32 +#include // @manual +using ssize_t = SSIZE_T; +#endif + +#include + +/** + * Similar to Python's enumerate(), enumerate() can be used to + * iterate a range with a for-range loop, and it also allows to + * retrieve the count of iterations so far. Can be used in constexpr + * context. + * + * For example: + * + * for (auto&& [index, element] : enumerate(vec)) { + * // index is a const reference to a size_t containing the iteration count. + * // element is a reference to the type contained within vec, mutable + * // unless vec is const. + * } + * + * If the binding is const, the element reference is too. + * + * for (const auto&& [index, element] : enumerate(vec)) { + * // element is always a const reference. + * } + * + * It can also be used as follows: + * + * for (auto&& it : enumerate(vec)) { + * // *it is a reference to the current element. Mutable unless vec is const. + * // it->member can be used as well. + * // it.index contains the iteration count. + * } + * + * As before, const auto&& it can also be used. + */ + +namespace c10 { + +namespace detail { + +template +struct MakeConst { + using type = const T; +}; +template +struct MakeConst { + using type = const T&; +}; +template +struct MakeConst { + using type = const T*; +}; + +template +class Enumerator { + public: + constexpr explicit Enumerator(Iterator it) : it_(std::move(it)) {} + + class Proxy { + public: + using difference_type = ssize_t; + using value_type = typename std::iterator_traits::value_type; + using reference = typename std::iterator_traits::reference; + using pointer = typename std::iterator_traits::pointer; + using iterator_category = std::input_iterator_tag; + + C10_ALWAYS_INLINE constexpr explicit Proxy(const Enumerator& e) + : index(e.idx_), element(*e.it_) {} + + // Non-const Proxy: Forward constness from Iterator. + C10_ALWAYS_INLINE constexpr reference operator*() { + return element; + } + C10_ALWAYS_INLINE constexpr pointer operator->() { + return std::addressof(element); + } + + // Const Proxy: Force const references. + C10_ALWAYS_INLINE constexpr typename MakeConst::type operator*() + const { + return element; + } + C10_ALWAYS_INLINE constexpr typename MakeConst::type operator->() + const { + return std::addressof(element); + } + + public: + size_t index; + reference element; + }; + + C10_ALWAYS_INLINE constexpr Proxy operator*() const { + return Proxy(*this); + } + + C10_ALWAYS_INLINE constexpr Enumerator& operator++() { + ++it_; + ++idx_; + return *this; + } + + template + C10_ALWAYS_INLINE constexpr bool operator==( + const Enumerator& rhs) const { + return it_ == rhs.it_; + } + + template + C10_ALWAYS_INLINE constexpr bool operator!=( + const Enumerator& rhs) const { + return !(it_ == rhs.it_); + } + + private: + template + friend class Enumerator; + + Iterator it_; + size_t idx_ = 0; +}; + +template +class RangeEnumerator { + Range r_; + using BeginIteratorType = decltype(std::declval().begin()); + using EndIteratorType = decltype(std::declval().end()); + + public: + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + constexpr explicit RangeEnumerator(Range&& r) : r_(std::forward(r)) {} + + constexpr Enumerator begin() { + return Enumerator(r_.begin()); + } + constexpr Enumerator end() { + return Enumerator(r_.end()); + } +}; + +} // namespace detail + +template +constexpr detail::RangeEnumerator enumerate(Range&& r) { + return detail::RangeEnumerator(std::forward(r)); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Exception.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Exception.h new file mode 100644 index 0000000000000000000000000000000000000000..c6b4a7fa25013fa413504a69fb177b0e1d6febcc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Exception.h @@ -0,0 +1,875 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_EXCEPTION_H_ +#define C10_UTIL_EXCEPTION_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) && _MSC_VER <= 1900 +#define __func__ __FUNCTION__ +#endif + +namespace c10 { + +/// The primary ATen error class. +/// Provides a complete error message with source location information via +/// `what()`, and a more concise message via `what_without_backtrace()`. +/// Don't throw this directly; use TORCH_CHECK/TORCH_INTERNAL_ASSERT instead. +/// +/// NB: c10::Error is handled specially by the default torch to suppress the +/// backtrace, see torch/csrc/Exceptions.h +class C10_API Error : public std::exception { + private: + // The actual error message. + std::string msg_; + + // Context for the message (in order of decreasing specificity). Context will + // be automatically formatted appropriately, so it is not necessary to add + // extra leading/trailing newlines to strings inside this vector + std::vector context_; + + // The C++ backtrace at the point when this exception was raised. This + // may be empty if there is no valid backtrace. (We don't use optional + // here to reduce the dependencies this file has.) + Backtrace backtrace_; + + // These two are derived fields from msg_stack_ and backtrace_, but we need + // fields for the strings so that we can return a const char* (as the + // signature of std::exception requires). Currently, the invariant + // is that these fields are ALWAYS populated consistently with respect + // to msg_stack_ and backtrace_. + mutable OptimisticLazy what_; + std::string what_without_backtrace_; + + // This is a little debugging trick: you can stash a relevant pointer + // in caller, and then when you catch the exception, you can compare + // against pointers you have on hand to get more information about + // where the exception came from. In Caffe2, this is used to figure + // out which operator raised an exception. + const void* caller_; + + public: + // PyTorch-style Error constructor. NB: the implementation of this + // is actually in Logging.cpp + Error(SourceLocation source_location, std::string msg); + + // Caffe2-style error message + Error( + const char* file, + const uint32_t line, + const char* condition, + const std::string& msg, + Backtrace backtrace, + const void* caller = nullptr); + + // Base constructor + Error( + std::string msg, + Backtrace backtrace = nullptr, + const void* caller = nullptr); + + // Add some new context to the message stack. The last added context + // will be formatted at the end of the context list upon printing. + // WARNING: This method is O(n) in the size of the stack, so don't go + // wild adding a ridiculous amount of context to error messages. + void add_context(std::string msg); + + const std::string& msg() const { + return msg_; + } + + const std::vector& context() const { + return context_; + } + + const Backtrace& backtrace() const; + + /// Returns the complete error message, including the source location. + /// The returned pointer is invalidated if you call add_context() on + /// this object. + const char* what() const noexcept override; + + const void* caller() const noexcept { + return caller_; + } + + /// Returns only the error message string, without source location. + /// The returned pointer is invalidated if you call add_context() on + /// this object. + virtual const char* what_without_backtrace() const noexcept { + return what_without_backtrace_.c_str(); + } + + private: + void refresh_what(); + std::string compute_what(bool include_backtrace) const; +}; + +class C10_API Warning { + public: + class C10_API UserWarning{}; + class C10_API DeprecationWarning{}; + + using warning_variant_t = std::variant; + + Warning( + warning_variant_t type, + const SourceLocation& source_location, + std::string msg, + bool verbatim); + + Warning( + warning_variant_t type, + SourceLocation source_location, + const char* msg, + bool verbatim); + + Warning( + warning_variant_t type, + SourceLocation source_location, + ::c10::detail::CompileTimeEmptyString msg, + bool verbatim); + + // Getters for members + warning_variant_t type() const; + const SourceLocation& source_location() const; + const std::string& msg() const; + bool verbatim() const; + + private: + // The type of warning + warning_variant_t type_; + + // Where the warning happened. + SourceLocation source_location_; + + // The actual warning message. + std::string msg_; + + // See note: [Verbatim Warnings] + bool verbatim_; +}; + +using UserWarning = Warning::UserWarning; +using DeprecationWarning = Warning::DeprecationWarning; + +// Issue a warning with a given message. Dispatched to the current +// warning handler. +void C10_API warn(const Warning& warning); + +class C10_API WarningHandler { + public: + virtual ~WarningHandler() = default; + /// The default warning handler. Prints the message to stderr. + virtual void process(const Warning& warning); +}; + +namespace WarningUtils { + +// Note: [Verbatim Warnings] +// Warnings originating in C++ code can appear out-of-place to Python users: +// a user runs a line in Python, but the warning references a line in C++. +// Some parts of PyTorch, like the JIT, are cognizant of this mismatch +// and take care to map warnings back to the user's program, but most +// of PyTorch simply throws a context-free warning. To allow warning +// handlers to add context where appropriate, warn takes the +// "verbatim" flag. When this is false a warning handler might append +// the C++ warning to a Python warning message that relates the warning +// back to the user's program. Callers who have already accounted for +// context in their warnings should set verbatim to true so their warnings +// appear without modification. + +/// Sets the global warning handler. This is not thread-safe, so it should +/// generally be called once during initialization or while holding the GIL +/// for programs that use python. +/// User is responsible for keeping the WarningHandler alive until +/// it is not needed. +C10_API void set_warning_handler(WarningHandler* handler) noexcept(true); +/// Gets the global warning handler. +C10_API WarningHandler* get_warning_handler() noexcept(true); + +class C10_API WarningHandlerGuard { + WarningHandler* prev_handler_; + + public: + WarningHandlerGuard(WarningHandler* new_handler) + : prev_handler_(c10::WarningUtils::get_warning_handler()) { + c10::WarningUtils::set_warning_handler(new_handler); + } + WarningHandlerGuard(WarningHandlerGuard&& other) = delete; + WarningHandlerGuard(const WarningHandlerGuard&) = delete; + WarningHandlerGuard& operator=(const WarningHandlerGuard&) = delete; + WarningHandlerGuard& operator=(WarningHandlerGuard&&) = delete; + ~WarningHandlerGuard() { + c10::WarningUtils::set_warning_handler(prev_handler_); + } +}; + +/// The TORCH_WARN_ONCE macro is difficult to test for. Use +/// setWarnAlways(true) to turn it into TORCH_WARN, which can be +/// tested for more easily. +C10_API void set_warnAlways(bool /*setting*/) noexcept(true); +C10_API bool get_warnAlways() noexcept(true); + +// A RAII guard that sets warn_always (not thread-local) on +// construction, and sets it back to the original value upon destruction. +struct C10_API WarnAlways { + public: + explicit WarnAlways(bool setting = true); + ~WarnAlways(); + + private: + bool prev_setting; +}; + +} // namespace WarningUtils + +// Like Error, but we always report the C++ backtrace, instead of only +// reporting when TORCH_SHOW_CPP_STACKTRACES +class C10_API ErrorAlwaysShowCppStacktrace : public Error { + using Error::Error; + const char* what_without_backtrace() const noexcept override { + return what(); + } +}; + +// Used in ATen for out-of-bound indices that can reasonably only be detected +// lazily inside a kernel (See: advanced indexing). These turn into +// IndexError when they cross to Python. +class C10_API IndexError : public Error { + using Error::Error; +}; + +// Used in ATen for invalid values. These turn into +// ValueError when they cross to Python. +class C10_API ValueError : public Error { + using Error::Error; +}; + +// Used in ATen for invalid types. These turn into +// TypeError when they cross to Python. +class C10_API TypeError : public Error { + using Error::Error; +}; + +// Used in ATen for functionality that is not implemented. These turn into +// NotImplementedError when they cross to Python. +class C10_API NotImplementedError : public Error { + using Error::Error; +}; + +// Used in ATen for buffer-related errors, e.g. trying to create a DLPack of +// an unsupported device. These turn into BufferError when they cross to +// Python. +class C10_API BufferError : public Error { + using Error::Error; +}; + +// Used in ATen for non finite indices. These turn into +// ExitException when they cross to Python. +class C10_API EnforceFiniteError : public Error { + using Error::Error; +}; + +// Used in Onnxifi backend lowering. These turn into +// ExitException when they cross to Python. +class C10_API OnnxfiBackendSystemError : public Error { + using Error::Error; +}; + +// Used for numerical errors from the linalg module. These +// turn into LinAlgError when they cross into Python. +class C10_API LinAlgError : public Error { + using Error::Error; +}; + +class C10_API OutOfMemoryError : public Error { + using Error::Error; +}; + +// Used for handling syntactic errors in input arguments. +// These turn into SyntaxError when the cross into Python. +class C10_API SyntaxError : public Error { + using Error::Error; +}; + +// Raised when accelerator API call hits an error. +// These turn into AcceleratorError when the cross into Python +class C10_API AcceleratorError : public Error { + int32_t error_code; + + public: + AcceleratorError(SourceLocation loc, int32_t code, const std::string& msg) + : Error(loc, msg), error_code(code) {} + int32_t get_error_code() const { + return error_code; + } +}; + +// Base error type for all distributed errors. +// These turn into DistError when they cross into Python. +class C10_API DistError : public Error { + using Error::Error; +}; + +// Used for collective communication library errors from the distributed module. +// These turn into DistBackendError when they cross into Python. +class C10_API DistBackendError : public DistError { + using DistError::DistError; +}; + +// Used for errors originating from the store. +// These turn into DistStoreError when they cross into Python. +class C10_API DistStoreError : public DistError { + using DistError::DistError; +}; + +// Used for errors originating from the TCP/IP stack and not from collective +// libraries. These turn into DistNetworkError when they cross into Python. +class C10_API DistNetworkError : public DistError { + using DistError::DistError; +}; + +// Raised when a queue is empty and a non-blocking pop is called. +// Translated to torch.distributed.QueueEmptyError in Python +class C10_API DistQueueEmptyError : public DistStoreError { + using DistStoreError::DistStoreError; +}; + +// A utility function to return an exception std::string by prepending its +// exception type before its what() content +C10_API std::string GetExceptionString(const std::exception& e); + +} // namespace c10 + +// Private helper macro for implementing TORCH_INTERNAL_ASSERT and TORCH_CHECK +// +// Note: In the debug build With MSVC, __LINE__ might be of long type (a.k.a +// int32_t), which is different from the definition of `SourceLocation` that +// requires unsigned int (a.k.a uint32_t) and may cause a compile error with the +// message: error C2397: conversion from 'long' to 'uint32_t' requires a +// narrowing conversion Here the static cast is used to pass the build. if this +// is used inside a lambda the __func__ macro expands to operator(), which isn't +// very useful, but hard to fix in a macro so suppressing the warning. +#define C10_THROW_ERROR(err_type, msg) \ + throw ::c10::err_type( \ + {__func__, __FILE__, static_cast(__LINE__)}, msg) + +#define C10_BUILD_ERROR(err_type, msg) \ + ::c10::err_type({__func__, __FILE__, static_cast(__LINE__)}, msg) + +// Private helper macro for workaround MSVC misexpansion of nested macro +// invocations involving __VA_ARGS__. See +// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly +#define C10_EXPAND_MSVC_WORKAROUND(x) x + +#include + +// ---------------------------------------------------------------------------- +// Error reporting macros +// ---------------------------------------------------------------------------- + +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_RETHROW(e, ...) \ + do { \ + (void)e; /* Suppress unused variable warning */ \ + throw; \ + } while (false) +#else +#define TORCH_RETHROW(e, ...) \ + do { \ + e.add_context(::c10::str(__VA_ARGS__)); \ + throw; \ + } while (false) +#endif + +// A utility macro to provide assert()-like functionality; that is, enforcement +// of internal invariants in code. It supports an arbitrary number of extra +// arguments (evaluated only on failure), which will be printed in the assert +// failure message using operator<< (this is useful to print some variables +// which may be useful for debugging.) +// +// Usage: +// TORCH_INTERNAL_ASSERT(should_be_true); +// TORCH_INTERNAL_ASSERT(x == 0, "x = ", x); +// +// Assuming no bugs in PyTorch, the conditions tested by this macro should +// always be true; e.g., it should be possible to disable all of these +// conditions without changing observable user behavior. If you would like to +// do error reporting for user input, please use TORCH_CHECK instead. +// +// NOTE: It is SAFE to use this macro in production code; on failure, this +// simply raises an exception, it does NOT unceremoniously quit the process +// (unlike assert()). +// +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_INTERNAL_ASSERT(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + #cond " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__)); \ + } +#else +// It would be nice if we could build a combined string literal out of +// the TORCH_INTERNAL_ASSERT prefix and a user-provided string literal +// as the first argument, but there doesn't seem to be any good way to +// do that while still supporting having a first argument that isn't a +// string literal. +#define TORCH_INTERNAL_ASSERT(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchInternalAssertFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + #cond \ + " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__) ":" C10_STRINGIZE( \ + __LINE__) ", please report a bug to PyTorch. ", \ + c10::str(__VA_ARGS__)); \ + } +#endif + +// A utility macro to make it easier to test for error conditions from user +// input. Like TORCH_INTERNAL_ASSERT, it supports an arbitrary number of extra +// arguments (evaluated only on failure), which will be printed in the error +// message using operator<< (e.g., you can pass any object which has +// operator<< defined. Most objects in PyTorch have these definitions!) +// +// Usage: +// TORCH_CHECK(should_be_true); // A default error message will be provided +// // in this case; but we recommend writing an +// // explicit error message, as it is more +// // user friendly. +// TORCH_CHECK(x == 0, "Expected x to be 0, but got ", x); +// +// On failure, this macro will raise an exception. If this exception propagates +// to Python, it will convert into a Python RuntimeError. +// +// NOTE: It is SAFE to use this macro in production code; on failure, this +// simply raises an exception, it does NOT unceremoniously quit the process +// (unlike CHECK() from glog.) +// +#define TORCH_CHECK_WITH(error_t, cond, ...) \ + TORCH_CHECK_WITH_MSG(error_t, cond, "", __VA_ARGS__) + +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_CHECK_MSG(cond, type, ...) \ + (#cond #type " CHECK FAILED at " C10_STRINGIZE(__FILE__)) +#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + C10_THROW_ERROR(Error, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \ + } +#else + +namespace c10::detail { +template +auto torchCheckMsgImpl(const char* /*msg*/, const Args&... args) { + return ::c10::str(args...); +} +inline C10_API const char* torchCheckMsgImpl(const char* msg) { + return msg; +} +// If there is just 1 user-provided C-string argument, use it. +inline C10_API const char* torchCheckMsgImpl( + const char* /*msg*/, + const char* args) { + return args; +} +} // namespace c10::detail + +#define TORCH_CHECK_MSG(cond, type, ...) \ + (::c10::detail::torchCheckMsgImpl( \ + "Expected " #cond \ + " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)", \ + ##__VA_ARGS__)) +#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + C10_THROW_ERROR(error_t, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \ + } +#endif + +namespace c10::detail { + +[[noreturn]] C10_API void torchCheckFail( + const char* func, + const char* file, + uint32_t line, + const std::string& msg); +[[noreturn]] C10_API void torchCheckFail( + const char* func, + const char* file, + uint32_t line, + const char* msg); + +// The c10::str() call that creates userMsg can have 1 of 3 return +// types depending on the number and types of arguments passed to +// TORCH_INTERNAL_ASSERT. 0 arguments will get a +// CompileTimeEmptyString, 1 const char * will be passed straight +// through, and anything else will get converted to std::string. +[[noreturn]] C10_API void torchInternalAssertFail( + const char* func, + const char* file, + uint32_t line, + const char* condMsg, + const char* userMsg); +[[noreturn]] inline C10_API void torchInternalAssertFail( + const char* func, + const char* file, + uint32_t line, + const char* condMsg, + ::c10::detail::CompileTimeEmptyString /*userMsg*/) { + torchCheckFail(func, file, line, condMsg); +} +[[noreturn]] C10_API void torchInternalAssertFail( + const char* func, + const char* file, + uint32_t line, + const char* condMsg, + const std::string& userMsg); + +} // namespace c10::detail + +#ifdef STANDALONE_TORCH_HEADER + +// TORCH_CHECK throws std::runtime_error instead of c10::Error which is +// useful when certain headers are used in a libtorch-independent way, +// e.g. when Vectorized is used in AOTInductor generated code. +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(TORCH_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + __VA_ARGS__)); \ + } +#else +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(TORCH_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + ##__VA_ARGS__)); \ + } +#endif + +#else + +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } +#else +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ + } +#endif + +#endif + +// An utility macro that does what `TORCH_CHECK` does if compiled in the host +// code, otherwise does nothing. Supposed to be used in the code shared between +// host and device code as an alternative for `TORCH_CHECK`. +#if defined(__CUDACC__) || defined(__HIPCC__) +#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) +#else +#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) TORCH_CHECK(cond, ##__VA_ARGS__) +#endif + +// Debug only version of TORCH_INTERNAL_ASSERT. This macro only checks in debug +// build, and does nothing in release build. It is appropriate to use +// in situations where you want to add an assert to a hotpath, but it is +// too expensive to run this assert on production builds. +#ifdef NDEBUG +// Optimized version - generates no code. +#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \ + while (false) \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)) +#else +#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)) +#endif + +// TODO: We're going to get a lot of similar looking string literals +// this way; check if this actually affects binary size. + +// Like TORCH_CHECK, but raises LinAlgError instead of Error. +#define TORCH_CHECK_LINALG(cond, ...) \ + TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__) + +// Like TORCH_CHECK, but raises IndexErrors instead of Errors. +#define TORCH_CHECK_INDEX(cond, ...) \ + TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__) + +// Like TORCH_CHECK, but raises ValueErrors instead of Errors. +#define TORCH_CHECK_VALUE(cond, ...) \ + TORCH_CHECK_WITH_MSG(ValueError, cond, "VALUE", __VA_ARGS__) + +// Like TORCH_CHECK, but raises TypeErrors instead of Errors. +#define TORCH_CHECK_TYPE(cond, ...) \ + TORCH_CHECK_WITH_MSG(TypeError, cond, "TYPE", __VA_ARGS__) + +// Like TORCH_CHECK, but raises NotImplementedErrors instead of Errors. +#define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ + TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__) + +// Like TORCH_CHECK, but raises BufferError instead of Errors. +#define TORCH_CHECK_BUFFER(cond, ...) \ + TORCH_CHECK_WITH_MSG(BufferError, cond, "TYPE", __VA_ARGS__) + +#define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \ + TORCH_CHECK_WITH_MSG( \ + ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__) + +#ifdef STRIP_ERROR_MESSAGES +#define WARNING_MESSAGE_STRING(...) \ + ::c10::detail::CompileTimeEmptyString {} +#else +#define WARNING_MESSAGE_STRING(...) ::c10::str(__VA_ARGS__) +#endif + +// Report a warning to the user. Accepts an arbitrary number of extra +// arguments which are concatenated into the warning message using operator<< +// +#ifdef DISABLE_WARN +#define _TORCH_WARN_WITH(...) ((void)0); +#else +#define _TORCH_WARN_WITH(warning_t, ...) \ + ::c10::warn(::c10::Warning( \ + warning_t(), \ + {__func__, __FILE__, static_cast(__LINE__)}, \ + WARNING_MESSAGE_STRING(__VA_ARGS__), \ + false)); +#endif + +#define TORCH_WARN(...) _TORCH_WARN_WITH(::c10::UserWarning, __VA_ARGS__); + +#define TORCH_WARN_DEPRECATION(...) \ + _TORCH_WARN_WITH(::c10::DeprecationWarning, __VA_ARGS__); + +// Report a warning to the user only once. Accepts an arbitrary number of extra +// arguments which are concatenated into the warning message using operator<< +// +#define _TORCH_WARN_ONCE(...) \ + [[maybe_unused]] static const auto C10_ANONYMOUS_VARIABLE( \ + torch_warn_once_) = [&] { \ + TORCH_WARN(__VA_ARGS__); \ + return true; \ + }() + +#ifdef DISABLE_WARN +#define TORCH_WARN_ONCE(...) ((void)0); +#else +#define TORCH_WARN_ONCE(...) \ + if (::c10::WarningUtils::get_warnAlways()) { \ + TORCH_WARN(__VA_ARGS__); \ + } else { \ + _TORCH_WARN_ONCE(__VA_ARGS__); \ + } +#endif + +// Report an error with a specific argument +// NOTE: using the argument name in TORCH_CHECK's message is preferred +#define TORCH_CHECK_ARG(cond, argN, ...) \ + TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__) + +#ifndef FATAL_IF +#ifdef C10_USE_GLOG +#define FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::google::GLOG_FATAL) \ + .stream() +#else +#define FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream() +#endif +#endif + +#ifndef NON_FATAL_IF +#ifdef C10_USE_GLOG +#define NON_FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger( \ + __FILE__, __LINE__, ::google::GLOG_FATAL, false) \ + .stream() +#else +#define NON_FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL, false) \ + .stream() +#endif +#endif + +// Binary comparison check macros +#define TORCH_CHECK_OP(val1, val2, op) \ + NON_FATAL_IF(((val1)op(val2))) \ + << "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \ + << (val2) << "). " + +#define TORCH_DCHECK_OP(val1, val2, op) \ + FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \ + << (val1) << " vs. " << (val2) << "). " + +#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==) +#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=) +#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=) +#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <) +#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=) +#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >) + +// Debug versions of TORCH_CHECK_OP macros +#ifndef NDEBUG +#define TORCH_DCHECK_EQ(val1, val2) TORCH_DCHECK_OP(val1, val2, ==) +#define TORCH_DCHECK_NE(val1, val2) TORCH_DCHECK_OP(val1, val2, !=) +#define TORCH_DCHECK_LE(val1, val2) TORCH_DCHECK_OP(val1, val2, <=) +#define TORCH_DCHECK_LT(val1, val2) TORCH_DCHECK_OP(val1, val2, <) +#define TORCH_DCHECK_GE(val1, val2) TORCH_DCHECK_OP(val1, val2, >=) +#define TORCH_DCHECK_GT(val1, val2) TORCH_DCHECK_OP(val1, val2, >) +#else // !NDEBUG +// Optimized versions - generate no code +#define TORCH_DCHECK_EQ(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, ==) +#define TORCH_DCHECK_NE(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, !=) +#define TORCH_DCHECK_LE(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, <=) +#define TORCH_DCHECK_LT(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, <) +#define TORCH_DCHECK_GE(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, >=) +#define TORCH_DCHECK_GT(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, >) +#endif // NDEBUG + +// Null pointer check macro +#define TORCH_CHECK_NOTNULL(val) \ + ::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), false) + +#ifndef NDEBUG +#define TORCH_DCHECK_NOTNULL(val) \ + ::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), true) +#else // !NDEBUG +#define TORCH_DCHECK_NOTNULL(val) \ + while (false) \ + TORCH_CHECK_NOTNULL(val) +#endif // NDEBUG + +// ---------------------------------------------------------------------------- +// Deprecated macros +// ---------------------------------------------------------------------------- + +namespace c10::detail { + +/* +// Deprecation disabled until we fix sites in our codebase +[[deprecated("AT_ERROR(msg) is deprecated, use TORCH_CHECK(false, msg) +instead.")]] +*/ +inline void deprecated_AT_ERROR() {} + +/* +// Deprecation disabled until we fix sites in our codebase +[[deprecated("AT_ASSERT is deprecated, if you mean to indicate an +internal invariant failure, use " \ + "TORCH_INTERNAL_ASSERT instead; if you mean to do user +error checking, use " \ "TORCH_CHECK. See +https://github.com/pytorch/pytorch/issues/20287 for more details.")]] +*/ +inline void deprecated_AT_ASSERT() {} + +/* +// Deprecation disabled until we fix sites in our codebase +[[deprecated("AT_ASSERTM is deprecated, if you mean to indicate an +internal invariant failure, use " \ + "TORCH_INTERNAL_ASSERT instead; if you mean to do user +error checking, use " \ "TORCH_CHECK. See +https://github.com/pytorch/pytorch/issues/20287 for more details.")]] +*/ +inline void deprecated_AT_ASSERTM() {} + +} // namespace c10::detail + +// Deprecated alias; this alias was deprecated because people kept mistakenly +// using it for user error checking. Use TORCH_INTERNAL_ASSERT or TORCH_CHECK +// instead. See https://github.com/pytorch/pytorch/issues/20287 for more +// details. +#define AT_ASSERT(...) \ + do { \ + ::c10::detail::deprecated_AT_ASSERT(); \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)); \ + } while (false) + +// Deprecated alias, like AT_ASSERT. The new TORCH_INTERNAL_ASSERT macro +// supports both 0-ary and variadic calls, so having a separate +// message-accepting macro is not necessary. +// +// NB: we MUST include cond explicitly here, as MSVC will miscompile the macro +// expansion, shunting all of __VA_ARGS__ to cond. An alternate workaround +// can be seen at +// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly +#define AT_ASSERTM(cond, ...) \ + do { \ + ::c10::detail::deprecated_AT_ASSERTM(); \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \ + } while (false) + +// Deprecated alias; this alias was deprecated because it represents extra API +// surface that makes it hard for people to understand what macro to use. +// Use TORCH_CHECK(false, ...) or TORCH_INTERNAL_ASSERT(false, ...) to +// unconditionally fail at a line of code. +#define AT_ERROR(...) \ + do { \ + ::c10::detail::deprecated_AT_ERROR(); \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \ + } while (false) + +#endif // C10_UTIL_EXCEPTION_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwned.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwned.h new file mode 100644 index 0000000000000000000000000000000000000000..24cdba8d3ea3d9850b673974971c9eca37ff365f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwned.h @@ -0,0 +1,145 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { + +// See example implementation in TensorBase.h and TensorBody.h. +// Synopsis: +// +// repr_type -- type to use to store an owned T in ExclusivelyOwned. +// +// pointer_type -- pointer-esque type to return from +// ExclusivelyOwned's get() and operator*() methods. +// +// const_pointer_type -- similar to pointer_type, used for the const methods. +// +// static repr_type nullRepr() -- return a null instance of repr_type. +// +// template +// static repr_type createInPlace(Args&&... args) -- used by the in-place +// ExclusivelyOwned constructor. +// +// static repr_type moveToRepr(T&& x) -- move the given x into an +// instance of repr_type. used by the ExclusivelyOwned(T&&) +// constructor. +// +// static void destroyOwned(repr_type x) -- free memory for a +// known-exclusively-owned instance of x. Replaces calling repr_type's +// destructor. Being able to implement this more efficiently than +// repr_type's destructor is the main reason to use ExclusivelyOwned +// for a type. +// +// static T take(repr_type&) -- move out of the given repr_type into an owned T. +// +// static pointer_type getImpl(const repr_type&) -- return a pointer +// to the given repr_type. May take repr_type by value if that is more +// efficient. +template +struct ExclusivelyOwnedTraits; + +/// ExclusivelyOwned is a smart-pointer-like wrapper around an +/// exclusively-owned instance of some type T that normally has +/// mandatory reference counting (currently just Tensor). If you have +/// an isolated piece of code that knows that it has sole ownership of +/// an object of one of these types (i.e., because you created it +/// directly or using a factory function) and that object will not +/// escape from that isolated piece of code, then moving the object +/// into an ExclusivelyOwned will avoid an atomic reference count +/// decrement at destruction time. +/// +/// If you directly create the Tensor in the first +/// place, you can use the in_place constructor of ExclusivelyOwned to +/// additionally avoid doing any stores to initialize the refcount & +/// weakcount. +template +class ExclusivelyOwned { + using EOT = ExclusivelyOwnedTraits; + typename ExclusivelyOwnedTraits::repr_type repr_; + + public: + ExclusivelyOwned() : repr_(EOT::nullRepr()) {} + + explicit ExclusivelyOwned(T&& t) : repr_(EOT::moveToRepr(std::move(t))) {} + + template + explicit ExclusivelyOwned(std::in_place_t /*unused*/, Args&&... args) + : repr_(EOT::createInPlace(std::forward(args)...)) {} + + ExclusivelyOwned(const ExclusivelyOwned&) = delete; + + ExclusivelyOwned(ExclusivelyOwned&& rhs) noexcept + : repr_(std::move(rhs.repr_)) { + rhs.repr_ = EOT::nullRepr(); + } + + ExclusivelyOwned& operator=(const ExclusivelyOwned&) = delete; + + ExclusivelyOwned& operator=(ExclusivelyOwned&& rhs) noexcept { + EOT::destroyOwned(repr_); + repr_ = std::move(rhs.repr_); + rhs.repr_ = EOT::nullRepr(); + return *this; + } + + ExclusivelyOwned& operator=(T&& rhs) noexcept { + EOT::destroyOwned(repr_); + repr_ = EOT::moveToRepr(std::move(rhs)); + return *this; + } + + ~ExclusivelyOwned() { + EOT::destroyOwned(repr_); + // Don't bother to call the destructor of repr_, since we already + // did specialized destruction for the exclusively-owned case in + // destroyOwned! + } + + // We don't provide this because it would require us to be able to + // differentiate an owned-but-empty T from a lack of T. This is + // particularly problematic for Tensor, which wants to use an + // undefined Tensor as its null state. + explicit operator bool() const noexcept = delete; + + operator T() && { + return take(); + } + + // NOTE: the equivalent operation on MaybeOwned is a moving + // operator*. For ExclusivelyOwned, take() and operator*() may well + // have different return types, so they are different functions. + T take() && { + return EOT::take(repr_); + } + + typename EOT::const_pointer_type operator->() const { + return get(); + } + + typename EOT::const_pointer_type get() const { + return EOT::getImpl(repr_); + } + + typename EOT::pointer_type operator->() { + return get(); + } + + typename EOT::pointer_type get() { + return EOT::getImpl(repr_); + } + + std::remove_pointer_t& operator*() const { + return *get(); + } + + std::remove_pointer_t& operator*() { + return *get(); + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..5b3a76fe9fc94776a70538d212e657435189b350 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include + +namespace c10 { +// Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and +// at::TensorBase. +template +struct ExclusivelyOwnedTensorTraits { + using repr_type = TensorType; + using pointer_type = TensorType*; + using const_pointer_type = const TensorType*; + + static repr_type nullRepr() { + return TensorType(); + } + + template + static repr_type createInPlace(Args&&... args) { + return TensorType(std::forward(args)...); + } + + static repr_type moveToRepr(TensorType&& x) { + return std::move(x); + } + + static void destroyOwned(TensorType& x) { + TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); + // May be 0 because UndefinedTensorImpl doesn't get its refcount + // incremented. + const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->refcount() == 1 || + (toDestroy->refcount() == 0 && isUndefined), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and refcount ", + toDestroy->refcount(), + ", expected 1 or, if isUndefined, 0!"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->weakcount() == 1 || + (toDestroy->weakcount() == 0 && + toDestroy == UndefinedTensorImpl::singleton()), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and weakcount ", + toDestroy->weakcount(), + ", expected 1 or, if isUndefined, 0!"); + if (!isUndefined) { +#ifndef NDEBUG + // Needed to pass the debug assertions in ~intrusive_ptr_target. + toDestroy->combined_refcount_.store(0, std::memory_order_relaxed); +#endif + delete toDestroy; + } + } + + static TensorType take(TensorType& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FbcodeMaps.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FbcodeMaps.h new file mode 100644 index 0000000000000000000000000000000000000000..8ce3648d928f50cf474d26cab63c16df16dda728 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FbcodeMaps.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_FBCODEMAPS_H_ +#define C10_UTIL_FBCODEMAPS_H_ + +// Map typedefs so that we can use folly's F14 maps in fbcode without +// taking a folly dependency. + +#ifdef FBCODE_CAFFE2 +#include +#include +#else +#include +#include +#endif + +namespace c10 { +#ifdef FBCODE_CAFFE2 +template +using FastMap = folly::F14FastMap; +template +using FastSet = folly::F14FastSet; +#else +template +using FastMap = std::unordered_map; +template +using FastSet = std::unordered_set; +#endif +} // namespace c10 + +#endif // C10_UTIL_FBCODEMAPS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FileSystem.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FileSystem.h new file mode 100644 index 0000000000000000000000000000000000000000..964c57668f629d342576a50f173247192e9f6c4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FileSystem.h @@ -0,0 +1,27 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Shim header for filesystem for compilers that are too old to have it not +// in the experimental namespace + +#if __has_include() +#include +#elif __has_include() +#include +#else +#error "Neither nor is available." +#endif + +namespace c10 { + +#if __has_include() +// NOLINTNEXTLINE(misc-unused-alias-decls) +namespace filesystem = std::filesystem; +#elif __has_include() +// NOLINTNEXTLINE(misc-unused-alias-decls) +namespace filesystem = std::experimental::filesystem; +#endif + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Flags.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Flags.h new file mode 100644 index 0000000000000000000000000000000000000000..c2485bfdebae3a17f0fc8131cfcf24c01052c2a9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Flags.h @@ -0,0 +1,247 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_FLAGS_H_ +#define C10_UTIL_FLAGS_H_ + +/* Commandline flags support for C10. + * + * This is a portable commandline flags tool for c10, so we can optionally + * choose to use gflags or a lightweight custom implementation if gflags is + * not possible on a certain platform. If you have gflags installed, set the + * macro C10_USE_GFLAGS will seamlessly route everything to gflags. + * + * To define a flag foo of type bool default to true, do the following in the + * *global* namespace: + * C10_DEFINE_bool(foo, true, "An example."); + * + * To use it in another .cc file, you can use C10_DECLARE_* as follows: + * C10_DECLARE_bool(foo); + * + * In both cases, you can then access the flag via FLAGS_foo. + * + * It is recommended that you build with gflags. To learn more about the flags + * usage, refer to the gflags page here: + * + * https://gflags.github.io/gflags/ + * + * Note about Python users / devs: gflags is initiated from a C++ function + * ParseCommandLineFlags, and is usually done in native binaries in the main + * function. As Python does not have a modifiable main function, it is usually + * difficult to change the flags after Python starts. Hence, it is recommended + * that one sets the default value of the flags to one that's acceptable in + * general - that will allow Python to run without wrong flags. + */ + +#include +#include + +#include + +namespace c10 { +/** + * Sets the usage message when a commandline tool is called with "--help". + */ +C10_API void SetUsageMessage(const std::string& str); + +/** + * Returns the usage message for the commandline tool set by SetUsageMessage. + */ +C10_API const char* UsageMessage(); + +/** + * Parses the commandline flags. + * + * This command parses all the commandline arguments passed in via pargc + * and argv. Once it is finished, partc and argv will contain the remaining + * commandline args that c10 does not deal with. Note that following + * convention, argv[0] contains the binary name and is not parsed. + */ +C10_API bool ParseCommandLineFlags(int* pargc, char*** pargv); + +/** + * Checks if the commandline flags has already been passed. + */ +C10_API bool CommandLineFlagsHasBeenParsed(); + +} // namespace c10 + +//////////////////////////////////////////////////////////////////////////////// +// Below are gflags and non-gflags specific implementations. +// In general, they define the following macros for one to declare (use +// C10_DECLARE) or define (use C10_DEFINE) flags: +// C10_{DECLARE,DEFINE}_{int,int64,double,bool,string} +//////////////////////////////////////////////////////////////////////////////// + +#ifdef C10_USE_GFLAGS + +//////////////////////////////////////////////////////////////////////////////// +// Begin gflags section: most functions are basically rerouted to gflags. +//////////////////////////////////////////////////////////////////////////////// +#include + +// C10 uses hidden visibility by default. However, in gflags, it only uses +// export on Windows platform (with dllexport) but not on linux/mac (with +// default visibility). As a result, to ensure that we are always exporting +// global variables, we will redefine the GFLAGS_DLL_DEFINE_FLAG macro if we +// are building C10 as a shared library. +// This has to be done after the inclusion of gflags, because some early +// versions of gflags.h (e.g. 2.0 on ubuntu 14.04) directly defines the +// macros, so we need to do definition after gflags is done. +#ifdef GFLAGS_DLL_DEFINE_FLAG +#undef GFLAGS_DLL_DEFINE_FLAG +#endif // GFLAGS_DLL_DEFINE_FLAG +#ifdef GFLAGS_DLL_DECLARE_FLAG +#undef GFLAGS_DLL_DECLARE_FLAG +#endif // GFLAGS_DLL_DECLARE_FLAG +#define GFLAGS_DLL_DEFINE_FLAG C10_EXPORT +#define GFLAGS_DLL_DECLARE_FLAG C10_IMPORT + +// gflags before 2.0 uses namespace google and after 2.1 uses namespace gflags. +// Using GFLAGS_GFLAGS_H_ to capture this change. +#ifndef GFLAGS_GFLAGS_H_ +namespace gflags = google; +#endif // GFLAGS_GFLAGS_H_ + +// Motivation about the gflags wrapper: +// (1) We would need to make sure that the gflags version and the non-gflags +// version of C10 are going to expose the same flags abstraction. One should +// explicitly use FLAGS_flag_name to access the flags. +// (2) For flag names, it is recommended to start with c10_ to distinguish it +// from regular gflags flags. For example, do +// C10_DEFINE_BOOL(c10_my_flag, true, "An example"); +// to allow one to use FLAGS_c10_my_flag. +// (3) Gflags has a design issue that does not properly expose the global flags, +// if one builds the library with -fvisibility=hidden. The current gflags (as of +// Aug 2018) only deals with the Windows case using dllexport, and not the Linux +// counterparts. As a result, we will explicitly use C10_EXPORT to export the +// flags defined in C10. This is done via a global reference, so the flag +// itself is not duplicated - under the hood it is the same global gflags flag. +#define C10_GFLAGS_DEF_WRAPPER(type, real_type, name, default_value, help_str) \ + DEFINE_##type(name, default_value, help_str); + +#define C10_DEFINE_int(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(int32, gflags::int32, name, default_value, help_str) +#define C10_DEFINE_int32(name, default_value, help_str) \ + C10_DEFINE_int(name, default_value, help_str) +#define C10_DEFINE_int64(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(int64, gflags::int64, name, default_value, help_str) +#define C10_DEFINE_double(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(double, double, name, default_value, help_str) +#define C10_DEFINE_bool(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(bool, bool, name, default_value, help_str) +#define C10_DEFINE_string(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(string, ::fLS::clstring, name, default_value, help_str) + +// DECLARE_typed_var should be used in header files and in the global namespace. +#define C10_GFLAGS_DECLARE_WRAPPER(type, real_type, name) DECLARE_##type(name); + +#define C10_DECLARE_int(name) \ + C10_GFLAGS_DECLARE_WRAPPER(int32, gflags::int32, name) +#define C10_DECLARE_int32(name) C10_DECLARE_int(name) +#define C10_DECLARE_int64(name) \ + C10_GFLAGS_DECLARE_WRAPPER(int64, gflags::int64, name) +#define C10_DECLARE_double(name) \ + C10_GFLAGS_DECLARE_WRAPPER(double, double, name) +#define C10_DECLARE_bool(name) C10_GFLAGS_DECLARE_WRAPPER(bool, bool, name) +#define C10_DECLARE_string(name) \ + C10_GFLAGS_DECLARE_WRAPPER(string, ::fLS::clstring, name) + +#define TORCH_DECLARE_int(name) C10_DECLARE_int(name) +#define TORCH_DECLARE_int32(name) C10_DECLARE_int32(name) +#define TORCH_DECLARE_int64(name) C10_DECLARE_int64(name) +#define TORCH_DECLARE_double(name) C10_DECLARE_double(name) +#define TORCH_DECLARE_bool(name) C10_DECLARE_bool(name) +#define TORCH_DECLARE_string(name) C10_DECLARE_string(name) + +//////////////////////////////////////////////////////////////////////////////// +// End gflags section. +//////////////////////////////////////////////////////////////////////////////// + +#else // C10_USE_GFLAGS + +//////////////////////////////////////////////////////////////////////////////// +// Begin non-gflags section: providing equivalent functionality. +//////////////////////////////////////////////////////////////////////////////// + +namespace c10 { + +class C10_API C10FlagParser { + public: + bool success() { + return success_; + } + + protected: + template + bool Parse(const std::string& content, T* value); + bool success_{false}; +}; + +C10_DECLARE_REGISTRY(C10FlagsRegistry, C10FlagParser, const std::string&); + +} // namespace c10 + +// The macros are defined outside the c10 namespace. In your code, you should +// write the C10_DEFINE_* and C10_DECLARE_* macros outside any namespace +// as well. + +#define C10_DEFINE_typed_var(type, name, default_value, help_str) \ + C10_EXPORT type FLAGS_##name = default_value; \ + namespace c10 { \ + namespace { \ + class C10FlagParser_##name : public C10FlagParser { \ + public: \ + explicit C10FlagParser_##name(const std::string& content) { \ + success_ = C10FlagParser::Parse(content, &FLAGS_##name); \ + } \ + }; \ + RegistererC10FlagsRegistry g_C10FlagsRegistry_##name( \ + #name, \ + C10FlagsRegistry(), \ + RegistererC10FlagsRegistry::DefaultCreator, \ + "(" #type ", default " #default_value ") " help_str); \ + } \ + } + +#define C10_DEFINE_int(name, default_value, help_str) \ + C10_DEFINE_typed_var(int, name, default_value, help_str) +#define C10_DEFINE_int32(name, default_value, help_str) \ + C10_DEFINE_int(name, default_value, help_str) +#define C10_DEFINE_int64(name, default_value, help_str) \ + C10_DEFINE_typed_var(int64_t, name, default_value, help_str) +#define C10_DEFINE_double(name, default_value, help_str) \ + C10_DEFINE_typed_var(double, name, default_value, help_str) +#define C10_DEFINE_bool(name, default_value, help_str) \ + C10_DEFINE_typed_var(bool, name, default_value, help_str) +#define C10_DEFINE_string(name, default_value, help_str) \ + C10_DEFINE_typed_var(std::string, name, default_value, help_str) + +// DECLARE_typed_var should be used in header files and in the global namespace. +#define C10_DECLARE_typed_var(type, name) C10_API extern type FLAGS_##name + +#define C10_DECLARE_int(name) C10_DECLARE_typed_var(int, name) +#define C10_DECLARE_int32(name) C10_DECLARE_int(name) +#define C10_DECLARE_int64(name) C10_DECLARE_typed_var(int64_t, name) +#define C10_DECLARE_double(name) C10_DECLARE_typed_var(double, name) +#define C10_DECLARE_bool(name) C10_DECLARE_typed_var(bool, name) +#define C10_DECLARE_string(name) C10_DECLARE_typed_var(std::string, name) + +#define TORCH_DECLARE_typed_var(type, name) TORCH_API extern type FLAGS_##name + +#define TORCH_DECLARE_int(name) TORCH_DECLARE_typed_var(int, name) +#define TORCH_DECLARE_int32(name) TORCH_DECLARE_int(name) +#define TORCH_DECLARE_int64(name) TORCH_DECLARE_typed_var(int64_t, name) +#define TORCH_DECLARE_double(name) TORCH_DECLARE_typed_var(double, name) +#define TORCH_DECLARE_bool(name) TORCH_DECLARE_typed_var(bool, name) +#define TORCH_DECLARE_string(name) TORCH_DECLARE_typed_var(std::string, name) + +//////////////////////////////////////////////////////////////////////////////// +// End non-gflags section. +//////////////////////////////////////////////////////////////////////////////// + +#endif // C10_USE_GFLAGS + +#endif // C10_UTIL_FLAGS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float4_e2m1fn_x2.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float4_e2m1fn_x2.h new file mode 100644 index 0000000000000000000000000000000000000000..fd690e5aa345ac097a2b4022b6e5a42677e403f8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float4_e2m1fn_x2.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..ed07b955168f7ab08b4a20657d8f36ea7cd4123c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn.h new file mode 100644 index 0000000000000000000000000000000000000000..ed07b955168f7ab08b4a20657d8f36ea7cd4123c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..30481a62430fdf08f2107bc1ab50e811314767f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz.h new file mode 100644 index 0000000000000000000000000000000000000000..30481a62430fdf08f2107bc1ab50e811314767f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2-inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..f4e0802e2f7b1a6712f95dea5b82267d8a8498dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2.h new file mode 100644 index 0000000000000000000000000000000000000000..f4e0802e2f7b1a6712f95dea5b82267d8a8498dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..f3e8c25099a630204f3c4ee345fd2a3653c14116 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz.h new file mode 100644 index 0000000000000000000000000000000000000000..f3e8c25099a630204f3c4ee345fd2a3653c14116 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu-inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..030b23d64750b7378c8fc281c96d2fe662e38d88 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu.h new file mode 100644 index 0000000000000000000000000000000000000000..030b23d64750b7378c8fc281c96d2fe662e38d88 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FunctionRef.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FunctionRef.h new file mode 100644 index 0000000000000000000000000000000000000000..342824b5b9095219b123ab4bfb19fbb3cd1a7819 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/FunctionRef.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===- llvm/ADT/STLExtras.h - Useful STL related functions ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some templates that are useful if you are working with the +// STL at all. +// +// No library is required when using these functions. +// +//===----------------------------------------------------------------------===// + +// c10: modified from llvm::function_ref +// c10: added more SFINAE to enable use in overloaded functions + +#pragma once + +#include +#include +#include + +namespace c10 { + +/// An efficient, type-erasing, non-owning reference to a callable. This is +/// intended for use as the type of a function parameter that is not used +/// after the function in question returns. +/// +/// This class does not own the callable, so it is not in general safe to store +/// a function_ref. +template +class function_ref; + +template +class function_ref { + Ret (*callback)(intptr_t callable, Params... params) = nullptr; + intptr_t callable{}; + + template + static Ret callback_fn(intptr_t callable, Params... params) { + return (*reinterpret_cast(callable))( + std::forward(params)...); + } + + public: + function_ref() = default; + function_ref(std::nullptr_t) {} + + template + function_ref( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + Callable&& callable, + std::enable_if_t, + function_ref>>* /*unused*/ + = nullptr, + std::enable_if_t, + Ret>>* /*unused*/ + = nullptr) + : callback(callback_fn>), + callable(reinterpret_cast(&callable)) {} + + Ret operator()(Params... params) const { + return callback(callable, std::forward(params)...); + } + + operator bool() const { + return callback; + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Gauge.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Gauge.h new file mode 100644 index 0000000000000000000000000000000000000000..b10ed7f5c9b33b99adbd031069af7c4e2fd3d0e3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Gauge.h @@ -0,0 +1,55 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include + +namespace c10::monitor { +namespace detail { + +class GaugeImpl; + +class GaugeBackendIf { + public: + virtual ~GaugeBackendIf() = default; + virtual void record(int64_t value) noexcept = 0; +}; + +class GaugeBackendFactoryIf { + public: + virtual ~GaugeBackendFactoryIf() = default; + + // May return nullptr if the gauge will be ignored by the given backend. + virtual std::unique_ptr create( + std::string_view key) noexcept = 0; +}; + +void C10_API + registerGaugeBackend(std::unique_ptr /*backend*/); +} // namespace detail + +// A handle to a Gauge. +class C10_API GaugeHandle { + public: + explicit GaugeHandle(std::string_view key); + void record(int64_t value); + + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + detail::GaugeImpl& impl_; +}; + +} // namespace c10::monitor + +#define STATIC_GAUGE(_key) \ + []() -> ::c10::monitor::GaugeHandle& { \ + static ::c10::monitor::GaugeHandle handle(#_key); \ + return handle; \ + }() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Half-inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Half-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..78c3d37c1698db15f05b3b3367765075be2d9046 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Half-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Half.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Half.h new file mode 100644 index 0000000000000000000000000000000000000000..0a3d4462657c7aa4d4e3827a2de811132911632b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Half.h @@ -0,0 +1,13 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// need to keep the following for BC because the APIs in here were exposed +// before migrating Half to torch/headeronly +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +#include +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/IdWrapper.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/IdWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..b985cd3e51c325b50dd5ee368c216689888123d6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/IdWrapper.h @@ -0,0 +1,82 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10 { + +/** + * This template simplifies generation of simple classes that wrap an id + * in a typesafe way. Namely, you can use it to create a very lightweight + * type that only offers equality comparators and hashing. Example: + * + * struct MyIdType final : IdWrapper { + * constexpr explicit MyIdType(uint32_t id): IdWrapper(id) {} + * }; + * + * Then in the global top level namespace: + * + * C10_DEFINE_HASH_FOR_IDWRAPPER(MyIdType); + * + * That's it - equality operators and hash functions are automatically defined + * for you, given the underlying type supports it. + */ +template +class IdWrapper { + public: + using underlying_type = UnderlyingType; + using concrete_type = ConcreteType; + + protected: + constexpr explicit IdWrapper(underlying_type id) noexcept( + noexcept(underlying_type(std::declval()))) + : id_(id) {} + + constexpr underlying_type underlyingId() const + noexcept(noexcept(underlying_type(std::declval()))) { + return id_; + } + + private: + friend size_t hash_value(const concrete_type& v) { + return std::hash()(v.id_); + } + + // TODO Making operator== noexcept if underlying type is noexcept equality + // comparable doesn't work with GCC 4.8. + // Fix this once we don't need GCC 4.8 anymore. + friend constexpr bool operator==( + const concrete_type& lhs, + const concrete_type& rhs) noexcept { + return lhs.id_ == rhs.id_; + } + + // TODO Making operator!= noexcept if operator== is noexcept doesn't work with + // GCC 4.8. + // Fix this once we don't need GCC 4.8 anymore. + friend constexpr bool operator!=( + const concrete_type& lhs, + const concrete_type& rhs) noexcept { + return !(lhs == rhs); + } + + underlying_type id_; +}; + +} // namespace c10 + +#define C10_DEFINE_HASH_FOR_IDWRAPPER(ClassName) \ + namespace std { \ + template <> \ + struct hash { \ + size_t operator()(ClassName x) const { \ + return hash_value(x); \ + } \ + }; \ + } + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/IntrusiveList.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/IntrusiveList.h new file mode 100644 index 0000000000000000000000000000000000000000..a28803082f7b641b92dae8acf320b7b9be348d74 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/IntrusiveList.h @@ -0,0 +1,211 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { + +template +class IntrusiveList; + +class IntrusiveListHook { + template + friend class ListIterator; + + template + friend class IntrusiveList; + + IntrusiveListHook* next_{nullptr}; + IntrusiveListHook* prev_{nullptr}; + + void link_before(IntrusiveListHook* next_node) { + next_ = next_node; + prev_ = next_node->prev_; + next_node->prev_ = this; + prev_->next_ = this; + } + + public: + IntrusiveListHook() : next_(this), prev_(this) {} + + IntrusiveListHook(const IntrusiveListHook&) = delete; + IntrusiveListHook& operator=(const IntrusiveListHook&) = delete; + IntrusiveListHook(IntrusiveListHook&&) = delete; + IntrusiveListHook& operator=(IntrusiveListHook&&) = delete; + + void unlink() { + TORCH_CHECK(is_linked()); + next_->prev_ = prev_; + prev_->next_ = next_; + next_ = this; + prev_ = this; + } + + ~IntrusiveListHook() { + if (is_linked()) { + unlink(); + } + } + + bool is_linked() const { + return next_ != this; + } +}; + +template +class ListIterator { + static_assert(std::is_same_v, IntrusiveListHook>); + static_assert(std::is_base_of_v); + P* ptr_; + + friend class IntrusiveList; + + public: + using iterator_category = std::bidirectional_iterator_tag; + using value_type = std::conditional_t, const T, T>; + using difference_type = std::ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + + explicit ListIterator(P* ptr) : ptr_(ptr) {} + ~ListIterator() = default; + + ListIterator(const ListIterator&) = default; + ListIterator& operator=(const ListIterator&) = default; + ListIterator(ListIterator&&) = default; + ListIterator& operator=(ListIterator&&) = default; + + template < + typename Q, + class = std::enable_if_t && !std::is_const_v>> + ListIterator(const ListIterator& rhs) : ptr_(rhs.ptr_) {} + + template < + typename Q, + class = std::enable_if_t && !std::is_const_v>> + ListIterator& operator=(const ListIterator& rhs) { + ptr_ = rhs.ptr_; + return *this; + } + + template + bool operator==(const ListIterator& other) const { + return ptr_ == other.ptr_; + } + + template + bool operator!=(const ListIterator& other) const { + return !(*this == other); + } + + auto& operator*() const { + return static_cast(*ptr_); + } + + ListIterator& operator++() { + TORCH_CHECK(ptr_); + ptr_ = ptr_->next_; + return *this; + } + + ListIterator& operator--() { + TORCH_CHECK(ptr_); + ptr_ = ptr_->prev_; + return *this; + } + + auto* operator->() const { + return static_cast(ptr_); + } +}; + +template +class IntrusiveList { + static_assert(std::is_base_of_v); + + public: + IntrusiveList() = default; + IntrusiveList(const std::initializer_list>& items) { + for (auto& item : items) { + insert(this->end(), item); + } + } + ~IntrusiveList() { + while (head_.is_linked()) { + head_.next_->unlink(); + } + } + IntrusiveList(const IntrusiveList&) = delete; + IntrusiveList& operator=(const IntrusiveList&) = delete; + IntrusiveList(IntrusiveList&&) = delete; + IntrusiveList& operator=(IntrusiveList&&) = delete; + + using iterator = ListIterator; + using const_iterator = ListIterator; + + auto begin() const { + return ++const_iterator{&head_}; + } + + auto begin() { + return ++iterator{&head_}; + } + + auto end() const { + return const_iterator{&head_}; + } + + auto end() { + return iterator{&head_}; + } + + auto rbegin() const { + return std::reverse_iterator{end()}; + } + + auto rbegin() { + return std::reverse_iterator{end()}; + } + + auto rend() const { + return std::reverse_iterator{begin()}; + } + + auto rend() { + return std::reverse_iterator{begin()}; + } + + auto iterator_to(const T& n) const { + return const_iterator{&n}; + } + + auto iterator_to(T& n) { + return iterator{&n}; + } + + iterator insert(iterator pos, T& n) { + n.link_before(pos.ptr_); + return iterator{&n}; + } + + size_t size() const { + size_t ret = 0; + for ([[maybe_unused]] auto& _ : *this) { + ret++; + } + return ret; + } + + bool empty() const { + return !head_.is_linked(); + } + + private: + IntrusiveListHook head_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Lazy.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Lazy.h new file mode 100644 index 0000000000000000000000000000000000000000..204fc205ef9940c397de23915c4fee7dba8673ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Lazy.h @@ -0,0 +1,125 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +/** + * Thread-safe lazy value with opportunistic concurrency: on concurrent first + * access, the factory may be called by multiple threads, but only one result is + * stored and its reference returned to all the callers. + * + * Value is heap-allocated; this optimizes for the case in which the value is + * never actually computed. + */ +template +class OptimisticLazy { + public: + OptimisticLazy() = default; + OptimisticLazy(const OptimisticLazy& other) { + if (T* value = other.value_.load(std::memory_order_acquire)) { + value_ = new T(*value); + } + } + OptimisticLazy(OptimisticLazy&& other) noexcept + : value_(other.value_.exchange(nullptr, std::memory_order_acq_rel)) {} + ~OptimisticLazy() { + reset(); + } + + template + T& ensure(const Factory& factory) { + if (T* value = value_.load(std::memory_order_acquire)) { + return *value; + } + T* value = new T(factory()); + T* old = nullptr; + if (!value_.compare_exchange_strong( + old, value, std::memory_order_release, std::memory_order_acquire)) { + delete value; + value = old; + } + return *value; + } + + // The following methods are not thread-safe: they should not be called + // concurrently with any other method. + + OptimisticLazy& operator=(const OptimisticLazy& other) { + *this = OptimisticLazy{other}; + return *this; + } + + OptimisticLazy& operator=(OptimisticLazy&& other) noexcept { + if (this != &other) { + reset(); + value_.store( + other.value_.exchange(nullptr, std::memory_order_acquire), + std::memory_order_release); + } + return *this; + } + + void reset() { + if (T* old = value_.load(std::memory_order_relaxed)) { + value_.store(nullptr, std::memory_order_relaxed); + delete old; + } + } + + private: + std::atomic value_{nullptr}; +}; + +/** + * Interface for a value that is computed on first access. + */ +template +class LazyValue { + public: + virtual ~LazyValue() = default; + + virtual const T& get() const = 0; +}; + +/** + * Convenience thread-safe LazyValue implementation with opportunistic + * concurrency. + */ +template +class OptimisticLazyValue : public LazyValue { + public: + const T& get() const override { + return value_.ensure([this] { return compute(); }); + } + + private: + virtual T compute() const = 0; + + mutable OptimisticLazy value_; +}; + +/** + * Convenience immutable (thus thread-safe) LazyValue implementation for cases + * in which the value is not actually lazy. + */ +template +class PrecomputedLazyValue : public LazyValue { + public: + PrecomputedLazyValue(T value) : value_(std::move(value)) {} + + const T& get() const override { + return value_; + } + + private: + T value_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/LeftRight.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/LeftRight.h new file mode 100644 index 0000000000000000000000000000000000000000..0435fffb73fdd7a8e6ef0cedc7d6feac6b818651 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/LeftRight.h @@ -0,0 +1,234 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +namespace detail { + +struct IncrementRAII final { + public: + explicit IncrementRAII(std::atomic* counter) : _counter(counter) { + _counter->fetch_add(1); + } + + ~IncrementRAII() { + _counter->fetch_sub(1); + } + IncrementRAII(IncrementRAII&&) = delete; + IncrementRAII& operator=(IncrementRAII&&) = delete; + + private: + std::atomic* _counter; + + C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); +}; + +} // namespace detail + +// LeftRight wait-free readers synchronization primitive +// https://hal.archives-ouvertes.fr/hal-01207881/document +// +// LeftRight is quite easy to use (it can make an arbitrary +// data structure permit wait-free reads), but it has some +// particular performance characteristics you should be aware +// of if you're deciding to use it: +// +// - Reads still incur an atomic write (this is how LeftRight +// keeps track of how long it needs to keep around the old +// data structure) +// +// - Writes get executed twice, to keep both the left and right +// versions up to date. So if your write is expensive or +// nondeterministic, this is also an inappropriate structure +// +// LeftRight is used fairly rarely in PyTorch's codebase. If you +// are still not sure if you need it or not, consult your local +// C++ expert. +// +template +class LeftRight final { + public: + template + explicit LeftRight(const Args&... args) + : _counters{{{0}, {0}}}, + _foregroundCounterIndex(0), + _foregroundDataIndex(0), + _data{{T{args...}, T{args...}}} {} + + // Copying and moving would not be threadsafe. + // Needs more thought and careful design to make that work. + LeftRight(const LeftRight&) = delete; + LeftRight(LeftRight&&) noexcept = delete; + LeftRight& operator=(const LeftRight&) = delete; + LeftRight& operator=(LeftRight&&) noexcept = delete; + + ~LeftRight() { + // wait until any potentially running writers are finished + { + std::unique_lock lock(_writeMutex); + } + + // wait until any potentially running readers are finished + while (_counters[0].load() != 0 || _counters[1].load() != 0) { + std::this_thread::yield(); + } + } + + template + auto read(F&& readFunc) const { + detail::IncrementRAII _increment_counter( + &_counters[_foregroundCounterIndex.load()]); + + return std::forward(readFunc)(_data[_foregroundDataIndex.load()]); + } + + // Throwing an exception in writeFunc is ok but causes the state to be either + // the old or the new state, depending on if the first or the second call to + // writeFunc threw. + template + auto write(F&& writeFunc) { + std::unique_lock lock(_writeMutex); + + return _write(std::forward(writeFunc)); + } + + private: + template + auto _write(const F& writeFunc) { + /* + * Assume, A is in background and B in foreground. In simplified terms, we + * want to do the following: + * 1. Write to A (old background) + * 2. Switch A/B + * 3. Write to B (new background) + * + * More detailed algorithm (explanations on why this is important are below + * in code): + * 1. Write to A + * 2. Switch A/B data pointers + * 3. Wait until A counter is zero + * 4. Switch A/B counters + * 5. Wait until B counter is zero + * 6. Write to B + */ + + auto localDataIndex = _foregroundDataIndex.load(); + + // 1. Write to A + _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + + // 2. Switch A/B data pointers + localDataIndex = localDataIndex ^ 1; + _foregroundDataIndex = localDataIndex; + + /* + * 3. Wait until A counter is zero + * + * In the previous write run, A was foreground and B was background. + * There was a time after switching _foregroundDataIndex (B to foreground) + * and before switching _foregroundCounterIndex, in which new readers could + * have read B but incremented A's counter. + * + * In this current run, we just switched _foregroundDataIndex (A back to + * foreground), but before writing to the new background B, we have to make + * sure A's counter was zero briefly, so all these old readers are gone. + */ + auto localCounterIndex = _foregroundCounterIndex.load(); + _waitForBackgroundCounterToBeZero(localCounterIndex); + + /* + * 4. Switch A/B counters + * + * Now that we know all readers on B are really gone, we can switch the + * counters and have new readers increment A's counter again, which is the + * correct counter since they're reading A. + */ + localCounterIndex = localCounterIndex ^ 1; + _foregroundCounterIndex = localCounterIndex; + + /* + * 5. Wait until B counter is zero + * + * This waits for all the readers on B that came in while both data and + * counter for B was in foreground, i.e. normal readers that happened + * outside of that brief gap between switching data and counter. + */ + _waitForBackgroundCounterToBeZero(localCounterIndex); + + // 6. Write to B + return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + } + + template + auto _callWriteFuncOnBackgroundInstance( + const F& writeFunc, + uint8_t localDataIndex) { + try { + return writeFunc(_data[localDataIndex ^ 1]); + } catch (...) { + // recover invariant by copying from the foreground instance + _data[localDataIndex ^ 1] = _data[localDataIndex]; + // rethrow + throw; + } + } + + void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { + while (_counters[counterIndex ^ 1].load() != 0) { + std::this_thread::yield(); + } + } + + mutable std::array, 2> _counters; + std::atomic _foregroundCounterIndex; + std::atomic _foregroundDataIndex; + std::array _data; + std::mutex _writeMutex; +}; + +// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a +// read-write lock to protect T (data). +template +class RWSafeLeftRightWrapper final { + public: + template + explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} + + // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight + // is not copyable or moveable. + RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; + RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; + RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; + RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; + ~RWSafeLeftRightWrapper() = default; + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + auto read(F&& readFunc) const { + return data_.withLock( + [&readFunc](T const& data) { return std::forward(readFunc)(data); }); + } + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + auto write(F&& writeFunc) { + return data_.withLock( + [&writeFunc](T& data) { return std::forward(writeFunc)(data); }); + } + + private: + c10::Synchronized data_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Load.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Load.h new file mode 100644 index 0000000000000000000000000000000000000000..38aef4c1ea38d790799e49f3f594ff8c3c7a0d78 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Load.h @@ -0,0 +1,43 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace c10 { +namespace detail { + +template +struct LoadImpl { + C10_HOST_DEVICE static T apply(const void* src) { + return *reinterpret_cast(src); + } +}; + +template <> +struct LoadImpl { + C10_HOST_DEVICE static bool apply(const void* src) { + static_assert(sizeof(bool) == sizeof(char)); + // NOTE: [Loading boolean values] + // Protect against invalid boolean values by loading as a byte + // first, then converting to bool (see gh-54789). + return *reinterpret_cast(src); + } +}; + +} // namespace detail + +template +C10_HOST_DEVICE constexpr T load(const void* src) { + return c10::detail::LoadImpl::apply(src); +} + +template +C10_HOST_DEVICE constexpr scalar_t load(const scalar_t* src) { + return c10::detail::LoadImpl::apply(src); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Logging.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Logging.h new file mode 100644 index 0000000000000000000000000000000000000000..49420110eb333a07e7b15cc6f21a8c77af52e84d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Logging.h @@ -0,0 +1,378 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_H_ +#define C10_UTIL_LOGGING_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// CAFFE2_LOG_THRESHOLD is a compile time flag that would allow us to turn off +// logging at compile time so no logging message below that level is produced +// at all. The value should be between INT_MIN and CAFFE_FATAL. +#ifndef CAFFE2_LOG_THRESHOLD +// If we have not defined the compile time log threshold, we keep all the +// log cases. +#define CAFFE2_LOG_THRESHOLD INT_MIN +#endif // CAFFE2_LOG_THRESHOLD + +// Below are different implementations for glog and non-glog cases. +#ifdef C10_USE_GLOG +#include +#else // !C10_USE_GLOG +#include +#endif // C10_USE_GLOG + +C10_DECLARE_int(caffe2_log_level); +C10_DECLARE_bool(caffe2_use_fatal_for_enforce); + +// Some versions of GLOG support less-spammy version of LOG_EVERY_MS. If it's +// not available - just short-circuit to the always working one one. +// We define the C10_ name to avoid confusing other files +#ifdef LOG_EVERY_MS +#define C10_LOG_EVERY_MS(severity, ms) LOG_EVERY_MS(severity, ms) +#else +#define C10_LOG_EVERY_MS(severity, ms) LOG(severity) +#endif + +// Same for LOG_FIRST_N +#ifdef LOG_FIRST_N +#define C10_LOG_FIRST_N(severity, n) LOG_FIRST_N(severity, n) +#else +#define C10_LOG_FIRST_N(severity, n) LOG(severity) +#endif + +// Same for LOG_EVERY_N +#ifdef LOG_EVERY_N +#define C10_LOG_EVERY_N(severity, n) LOG_EVERY_N(severity, n) +#else +#define C10_LOG_EVERY_N(severity, n) LOG(severity) +#endif + +namespace c10 { + +#if !defined(C10_NODEPRECATED) +using std::string; +#endif + +// Functions that we use for initialization. +C10_API bool InitCaffeLogging(int* argc, char** argv); +C10_API void UpdateLoggingLevelsFromFlags(); + +[[noreturn]] C10_API void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller = nullptr); + +[[noreturn]] C10_API void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const char* msg, + const void* caller = nullptr); + +[[noreturn]] inline void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + detail::CompileTimeEmptyString /*msg*/, + const void* caller = nullptr) { + ThrowEnforceNotMet(file, line, condition, "", caller); +} + +[[noreturn]] C10_API void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller = nullptr); + +[[noreturn]] C10_API void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + const char* msg, + const void* caller = nullptr); + +[[noreturn]] inline void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + detail::CompileTimeEmptyString /*msg*/, + const void* caller = nullptr) { + ThrowEnforceFiniteNotMet(file, line, condition, "", caller); +} + +constexpr bool IsUsingGoogleLogging() { +#ifdef C10_USE_GLOG + return true; +#else + return false; +#endif +} + +/** + * A utility to allow one to show log info to stderr after the program starts. + * + * This is similar to calling GLOG's --logtostderr, or setting caffe2_log_level + * to smaller than INFO. You are recommended to only use this in a few sparse + * cases, such as when you want to write a tutorial or something. Normally, use + * the commandline flags to set the log level. + */ +C10_API void ShowLogInfoToStderr(); + +C10_API void SetStackTraceFetcher(std::function<::c10::Backtrace()> fetcher); + +/** + * Convenience function for non-lazy stack trace fetchers. The Backtrace + * overload should be preferred when stringifying the backtrace is expensive. + */ +C10_API void SetStackTraceFetcher(std::function fetcher); + +using EnforceNotMet = ::c10::Error; + +#define CAFFE_ENFORCE(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \ + } \ + } while (false) + +#define CAFFE_ENFORCE_FINITE(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceFiniteNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \ + } \ + } while (false) + +#define CAFFE_ENFORCE_WITH_CALLER(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__), this); \ + } \ + } while (false) + +#define CAFFE_THROW(...) \ + ::c10::ThrowEnforceNotMet(__FILE__, __LINE__, "", ::c10::str(__VA_ARGS__)) + +/** + * Rich logging messages + * + * CAFFE_ENFORCE_THAT can be used with one of the "checker functions" that + * capture input argument values and add it to the exception message. E.g. + * `CAFFE_ENFORCE_THAT(Equals(foo(x), bar(y)), "Optional additional message")` + * would evaluate both foo and bar only once and if the results are not equal - + * include them in the exception message. + * + * Some of the basic checker functions like Equals or Greater are already + * defined below. Other header might define customized checkers by adding + * functions to caffe2::enforce_detail namespace. For example: + * + * namespace caffe2 { namespace enforce_detail { + * inline EnforceFailMessage IsVector(const vector& shape) { + * if (shape.size() == 1) { return EnforceOK(); } + * return c10::str("Shape ", shape, " is not a vector"); + * } + * }} + * + * With further usages like `CAFFE_ENFORCE_THAT(IsVector(Input(0).dims()))` + * + * Convenient wrappers for binary operations like CAFFE_ENFORCE_EQ are provided + * too. Please use them instead of TORCH_CHECK_EQ and friends for failures in + * user-provided input. + */ + +namespace enforce_detail { + +template +std::string enforceFailMsgImpl(const T1& x, const T2& y) { + return c10::str(x, " vs ", y); +} + +template +std::string enforceFailMsgImpl(const T1& x, const T2& y, const Args&... args) { + return c10::str(x, " vs ", y, ". ", args...); +} + +template +void enforceThatImpl( + Pred p, + const T1& lhs, + const T2& rhs, + const char* file, + int line, + const char* expr, + const void* caller, + GetFailMsgFunc getFailMsg) { + if (C10_UNLIKELY(!(p(lhs, rhs)))) { + ::c10::ThrowEnforceNotMet(file, line, expr, getFailMsg(lhs, rhs), caller); + } +} + +#define CAFFE_ENFORCE_THAT_IMPL(op, lhs, rhs, expr, ...) \ + ::c10::enforce_detail::enforceThatImpl( \ + op, \ + (lhs), \ + (rhs), \ + __FILE__, \ + __LINE__, \ + expr, \ + nullptr, \ + [&](const auto& arg1, const auto& arg2) { \ + return ::c10::enforce_detail::enforceFailMsgImpl( \ + arg1, arg2, ##__VA_ARGS__); \ + }) + +#define CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER(op, lhs, rhs, expr, ...) \ + ::c10::enforce_detail::enforceThatImpl( \ + op, \ + (lhs), \ + (rhs), \ + __FILE__, \ + __LINE__, \ + expr, \ + this, \ + [&](const auto& arg1, const auto& arg2) { \ + return ::c10::enforce_detail::enforceFailMsgImpl( \ + arg1, arg2, ##__VA_ARGS__); \ + }) + +} // namespace enforce_detail + +#define CAFFE_ENFORCE_THAT(cmp, op, lhs, rhs, ...) \ + CAFFE_ENFORCE_THAT_IMPL(cmp, lhs, rhs, #lhs " " #op " " #rhs, ##__VA_ARGS__) + +#define CAFFE_ENFORCE_BINARY_OP(cmp, op, x, y, ...) \ + CAFFE_ENFORCE_THAT_IMPL(cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_EQ(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::equal_to(), ==, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_NE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::not_equal_to(), !=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::less_equal(), <=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LT(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::less(), <, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::greater_equal(), >=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GT(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::greater(), >, x, y, ##__VA_ARGS__) + +#define CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(cmp, op, x, y, ...) \ + CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER( \ + cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_EQ_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::equal_to(), ==, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_NE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::not_equal_to(), !=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::less_equal(), <=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LT_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(std::less(), <, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::greater_equal(), >=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GT_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::greater(), >, x, y, ##__VA_ARGS__) + +struct IValue; +class C10_API EventSampledHandler { + public: + virtual void log( + std::string_view model_id, + const std::vector& args) = 0; + virtual ~EventSampledHandler() = default; +}; + +#define C10_LOG_EVENT_SAMPLED(event, ...) \ + static const std::unique_ptr<::c10::EventSampledHandler>& \ + _##event##EventSampledHandler = ::c10::GetEventSampledHandler(#event); \ + if (_##event##EventSampledHandler) { \ + _##event##EventSampledHandler->log(__VA_ARGS__); \ + } + +// Must be called in the main thread before any other threads are spawned. +C10_API void InitEventSampledHandlers( + std::vector>> /*handlers*/); +C10_API const std::unique_ptr& GetEventSampledHandler( + std::string_view /*event*/); + +/** + * Very lightweight logging for the first time API usage. It's beneficial for + * tracking of individual functionality usage in larger applications. + * + * In order to ensure light-weightedness of logging, we utilize static variable + * trick - LogAPIUsage will be invoked only once and further invocations will + * just do an atomic check. + * + * Example: + * // Logs caller info with an arbitrary text event, if there is a usage. + * C10_LOG_API_USAGE_ONCE("my_api"); + */ +#define C10_LOG_API_USAGE_ONCE(...) \ + [[maybe_unused]] static bool C10_ANONYMOUS_VARIABLE(logFlag) = \ + ::c10::detail::LogAPIUsageFakeReturn(__VA_ARGS__); + +// API usage logging capabilities +C10_API void SetAPIUsageLogger(std::function logger); +C10_API void LogAPIUsage(const std::string& context); + +C10_API void SetAPIUsageMetadataLogger( + std::function& metadata_map)> logger); +C10_API void LogAPIUsageMetadata( + const std::string& context, + const std::map& metadata_map); + +// PyTorch ddp usage logging capabilities +// DDPLoggingData holds data that can be logged in applications +// for analysis and debugging. Data structure is defined in +// c10 directory so that it can be easily imported by both c10 +// and torch files. +struct DDPLoggingData { + // logging fields that are string types. + std::map strs_map; + // logging fields that are int64_t types. + std::map ints_map; +}; + +C10_API void SetPyTorchDDPUsageLogger( + std::function logger); +C10_API void LogPyTorchDDPUsage(const DDPLoggingData& ddpData); + +namespace detail { +// Return value is needed to do the static variable initialization trick +C10_API bool LogAPIUsageFakeReturn(const std::string& context); +} // namespace detail + +// Initializes the c10 logger. +C10_API void initLogging(); + +// Sets the rank, which will be included in log messages +C10_API void SetGlobalRank(int64_t rank); + +} // namespace c10 + +#endif // C10_UTIL_LOGGING_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/MathConstants.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/MathConstants.h new file mode 100644 index 0000000000000000000000000000000000000000..f3e86ce2e1da5bc4d1d40ad22e4f31280ac16c2e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/MathConstants.h @@ -0,0 +1,147 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +namespace c10 { +// TODO: Replace me with inline constexpr variable when C++17 becomes available +namespace detail { +template +C10_HOST_DEVICE inline constexpr T e() { + return static_cast(2.718281828459045235360287471352662); +} + +template +C10_HOST_DEVICE inline constexpr T euler() { + return static_cast(0.577215664901532860606512090082402); +} + +template +C10_HOST_DEVICE inline constexpr T frac_1_pi() { + return static_cast(0.318309886183790671537767526745028); +} + +template +C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() { + return static_cast(0.564189583547756286948079451560772); +} + +template +C10_HOST_DEVICE inline constexpr T frac_sqrt_2() { + return static_cast(0.707106781186547524400844362104849); +} + +template +C10_HOST_DEVICE inline constexpr T frac_sqrt_3() { + return static_cast(0.577350269189625764509148780501957); +} + +template +C10_HOST_DEVICE inline constexpr T golden_ratio() { + return static_cast(1.618033988749894848204586834365638); +} + +template +C10_HOST_DEVICE inline constexpr T ln_10() { + return static_cast(2.302585092994045684017991454684364); +} + +template +C10_HOST_DEVICE inline constexpr T ln_2() { + return static_cast(0.693147180559945309417232121458176); +} + +template +C10_HOST_DEVICE inline constexpr T log_10_e() { + return static_cast(0.434294481903251827651128918916605); +} + +template +C10_HOST_DEVICE inline constexpr T log_2_e() { + return static_cast(1.442695040888963407359924681001892); +} + +template +C10_HOST_DEVICE inline constexpr T pi() { + return static_cast(3.141592653589793238462643383279502); +} + +template +C10_HOST_DEVICE inline constexpr T sqrt_2() { + return static_cast(1.414213562373095048801688724209698); +} + +template +C10_HOST_DEVICE inline constexpr T sqrt_3() { + return static_cast(1.732050807568877293527446341505872); +} + +template <> +C10_HOST_DEVICE inline constexpr BFloat16 pi() { + // According to + // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values + // pi is encoded as 4049 + return BFloat16(0x4049, BFloat16::from_bits()); +} + +template <> +C10_HOST_DEVICE inline constexpr Half pi() { + return Half(0x4248, Half::from_bits()); +} +} // namespace detail + +template +constexpr T e = c10::detail::e(); + +template +constexpr T euler = c10::detail::euler(); + +template +constexpr T frac_1_pi = c10::detail::frac_1_pi(); + +template +constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi(); + +template +constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2(); + +template +constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3(); + +template +constexpr T golden_ratio = c10::detail::golden_ratio(); + +template +constexpr T ln_10 = c10::detail::ln_10(); + +template +constexpr T ln_2 = c10::detail::ln_2(); + +template +constexpr T log_10_e = c10::detail::log_10_e(); + +template +constexpr T log_2_e = c10::detail::log_2_e(); + +template +constexpr T pi = c10::detail::pi(); + +template +constexpr T sqrt_2 = c10::detail::sqrt_2(); + +template +constexpr T sqrt_3 = c10::detail::sqrt_3(); +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/MaybeOwned.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/MaybeOwned.h new file mode 100644 index 0000000000000000000000000000000000000000..61e6ed82f27a4a2b91300f0987612f5a03c3bea1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/MaybeOwned.h @@ -0,0 +1,242 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +/// MaybeOwnedTraits describes how to borrow from T. Here is how we +/// can implement borrowing from an arbitrary type T using a raw +/// pointer to const: +template +struct MaybeOwnedTraitsGenericImpl { + using owned_type = T; + using borrow_type = const T*; + + static borrow_type createBorrow(const owned_type& from) { + return &from; + } + + static void assignBorrow(borrow_type& lhs, borrow_type rhs) { + lhs = rhs; + } + + static void destroyBorrow(borrow_type& /*toDestroy*/) {} + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return *borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static bool debugBorrowIsValid(const borrow_type& borrow) { + return borrow != nullptr; + } +}; + +/// It is possible to eliminate the extra layer of indirection for +/// borrows for some types that we control. For examples, see +/// intrusive_ptr.h and TensorBody.h. + +template +struct MaybeOwnedTraits; + +// Explicitly enable MaybeOwned>, rather than allowing +// MaybeOwned to be used for any type right away. +template +struct MaybeOwnedTraits> + : public MaybeOwnedTraitsGenericImpl> {}; + +/// A smart pointer around either a borrowed or owned T. When +/// constructed with borrowed(), the caller MUST ensure that the +/// borrowed-from argument outlives this MaybeOwned. Compare to +/// Rust's std::borrow::Cow +/// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note +/// that it is probably not suitable for general use because C++ has +/// no borrow checking. Included here to support +/// Tensor::expect_contiguous. +template +class MaybeOwned final { + using borrow_type = typename MaybeOwnedTraits::borrow_type; + using owned_type = typename MaybeOwnedTraits::owned_type; + + bool isBorrowed_; + union { + borrow_type borrow_; + owned_type own_; + }; + + /// Don't use this; use borrowed() instead. + explicit MaybeOwned(const owned_type& t) + : isBorrowed_(true), borrow_(MaybeOwnedTraits::createBorrow(t)) {} + + /// Don't use this; use owned() instead. + explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible_v) + : isBorrowed_(false), own_(std::move(t)) {} + + /// Don't use this; use owned() instead. + template + explicit MaybeOwned(std::in_place_t /*unused*/, Args&&... args) + : isBorrowed_(false), own_(std::forward(args)...) {} + + public: + explicit MaybeOwned() : isBorrowed_(true), borrow_() {} + + // Copying a borrow yields another borrow of the original, as with a + // T*. Copying an owned T yields another owned T for safety: no + // chains of borrowing by default! (Note you could get that behavior + // with MaybeOwned::borrowed(*rhs) if you wanted it.) + MaybeOwned(const MaybeOwned& rhs) : isBorrowed_(rhs.isBorrowed_) { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + new (&own_) T(rhs.own_); + } + } + + MaybeOwned& operator=(const MaybeOwned& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_UNLIKELY(!isBorrowed_)) { + if (rhs.isBorrowed_) { + own_.~T(); + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + isBorrowed_ = true; + } else { + own_ = rhs.own_; + } + } else { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + new (&own_) T(rhs.own_); + isBorrowed_ = false; + } + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_); + return *this; + } + + MaybeOwned(MaybeOwned&& rhs) noexcept( + // NOLINTNEXTLINE(*-noexcept-move-*) + std::is_nothrow_move_constructible_v && + std::is_nothrow_move_assignable_v) + : isBorrowed_(rhs.isBorrowed_) { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + new (&own_) T(std::move(rhs.own_)); + } + } + + MaybeOwned& operator=(MaybeOwned&& rhs) noexcept( + std::is_nothrow_move_assignable_v && + std::is_nothrow_move_assignable_v && + std::is_nothrow_move_constructible_v && + // NOLINTNEXTLINE(*-noexcept-move-*) + std::is_nothrow_destructible_v && + std::is_nothrow_destructible_v) { + if (this == &rhs) { + return *this; + } + if (C10_UNLIKELY(!isBorrowed_)) { + if (rhs.isBorrowed_) { + own_.~T(); + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + isBorrowed_ = true; + } else { + own_ = std::move(rhs.own_); + } + } else { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + new (&own_) T(std::move(rhs.own_)); + isBorrowed_ = false; + } + } + return *this; + } + + static MaybeOwned borrowed(const T& t) { + return MaybeOwned(t); + } + + static MaybeOwned owned(T&& t) noexcept( + std::is_nothrow_move_constructible_v) { + return MaybeOwned(std::move(t)); + } + + template + static MaybeOwned owned(std::in_place_t /*unused*/, Args&&... args) { + return MaybeOwned(std::in_place, std::forward(args)...); + } + + ~MaybeOwned() noexcept( + // NOLINTNEXTLINE(*-noexcept-destructor) + std::is_nothrow_destructible_v && + std::is_nothrow_destructible_v) { + if (C10_UNLIKELY(!isBorrowed_)) { + own_.~T(); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + } + } + + // This is an implementation detail! You should know what you're doing + // if you are testing this. If you just want to guarantee ownership move + // this into a T + bool unsafeIsBorrowed() const { + return isBorrowed_; + } + + const T& operator*() const& { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + } + return C10_LIKELY(isBorrowed_) + ? MaybeOwnedTraits::referenceFromBorrow(borrow_) + : own_; + } + + const T* operator->() const { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + } + return C10_LIKELY(isBorrowed_) + ? MaybeOwnedTraits::pointerFromBorrow(borrow_) + : &own_; + } + + // If borrowed, copy the underlying T. If owned, move from + // it. borrowed/owned state remains the same, and either we + // reference the same borrow as before or we are an owned moved-from + // T. + T operator*() && { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + return MaybeOwnedTraits::referenceFromBorrow(borrow_); + } else { + return std::move(own_); + } + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Metaprogramming.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Metaprogramming.h new file mode 100644 index 0000000000000000000000000000000000000000..55c3fb2ba6db0dbc8bf8d00e616aefc3acab7c85 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Metaprogramming.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/NetworkFlow.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/NetworkFlow.h new file mode 100644 index 0000000000000000000000000000000000000000..e029ae65773be41aa7d05402fc3e1c3d50dbb8a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/NetworkFlow.h @@ -0,0 +1,59 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include + +/** + * This file provides a network flow implementation. + * https://en.wikipedia.org/wiki/Flow_network + * + * It aims to mirror some of the behavior of networkx, which is/was used by + * functorch partitioners for splitting the graph into a forward and backward + * graph. + */ + +namespace c10 { + +enum class C10_API_ENUM MinCutStatus { + SUCCESS = 0, + UNBOUNDED = 1, + OVERFLOW_INF = 2, + INVALID = 3, +}; + +struct MinCutResult { + MinCutStatus status; + int64_t max_flow; + std::vector reachable; + std::vector unreachable; +}; + +// Modeled after networkx implementation +class C10_API NetworkFlowGraph { + public: + // selected such that INF + INF is < INT64_MAX + constexpr static int64_t INF = (1LL << 62) - 1; + + struct Edge { + std::string source, dest; + int64_t capacity; + }; + + MinCutStatus add_edge( + const std::string& source, + const std::string& dest, + int64_t capacity = 1); + + MinCutResult minimum_cut(const std::string& s, const std::string& t) const; + + std::vector edges; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Optional.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Optional.h new file mode 100644 index 0000000000000000000000000000000000000000..55c4697368c60f86b69db1b1bc65cf0cb2e99404 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Optional.h @@ -0,0 +1,65 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_OPTIONAL_H_ +#define C10_UTIL_OPTIONAL_H_ + +#include +#include + +// Macros.h is not needed, but it does namespace shenanigans that lots +// of downstream code seems to rely on. Feel free to remove it and fix +// up builds. + +namespace c10 { + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::bad_optional_access; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::make_optional; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::nullopt; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::nullopt_t; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::optional; +#endif + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) + +namespace detail_ { +// the call to convert(b) has return type A and converts b to type A iff b +// decltype(b) is implicitly convertible to A +template +constexpr U convert(U v) { + return v; +} +} // namespace detail_ +template +[[deprecated( + "Please use std::optional::value_or instead of c10::value_or_else")]] constexpr T +value_or_else(const std::optional& v, F&& func) { + static_assert( + std::is_convertible_v, T>, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() ? *v : detail_::convert(std::forward(func)()); +} + +template +[[deprecated( + "Please use std::optional::value_or instead of c10::value_or_else")]] constexpr T +value_or_else(std::optional&& v, F&& func) { + static_assert( + std::is_convertible_v, T>, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() ? constexpr_move(std::move(v).contained_val()) + : detail_::convert(std::forward(func)()); +} + +#endif + +} // namespace c10 +#endif // C10_UTIL_OPTIONAL_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/OptionalArrayRef.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/OptionalArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..cd15a5f19d1db7673c8f3485a136d1730e34f433 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/OptionalArrayRef.h @@ -0,0 +1,242 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// This file defines OptionalArrayRef, a class that has almost the same +// exact functionality as std::optional>, except that its +// converting constructor fixes a dangling pointer issue. +// +// The implicit converting constructor of both std::optional> and +// std::optional> can cause the underlying ArrayRef to store +// a dangling pointer. OptionalArrayRef prevents this by wrapping +// a std::optional> and fixing the constructor implementation. +// +// See https://github.com/pytorch/pytorch/issues/63645 for more on this. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +template +class OptionalArrayRef final { + public: + // Constructors + + constexpr OptionalArrayRef() noexcept = default; + + constexpr OptionalArrayRef(std::nullopt_t /*unused*/) noexcept {} + + OptionalArrayRef(const OptionalArrayRef& other) = default; + + OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef(const std::optional>& other) noexcept + : wrapped_opt_array_ref(other) {} + + constexpr OptionalArrayRef(std::optional>&& other) noexcept + : wrapped_opt_array_ref(std::move(other)) {} + + constexpr OptionalArrayRef(const T& value) noexcept + : wrapped_opt_array_ref(value) {} + + template < + typename U = ArrayRef, + std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + std::is_convertible_v> && + !std::is_convertible_v, + bool> = false> + constexpr OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template < + typename U = ArrayRef, + std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + !std::is_convertible_v>, + bool> = false> + constexpr explicit OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template + constexpr explicit OptionalArrayRef( + std::in_place_t ip, + Args&&... args) noexcept + : wrapped_opt_array_ref(ip, std::forward(args)...) {} + + template + constexpr explicit OptionalArrayRef( + std::in_place_t ip, + std::initializer_list il, + Args&&... args) + : wrapped_opt_array_ref(ip, il, std::forward(args)...) {} + + constexpr OptionalArrayRef(const std::initializer_list& Vec) + : wrapped_opt_array_ref(ArrayRef(Vec)) {} + + // Destructor + + ~OptionalArrayRef() = default; + + // Assignment + + constexpr OptionalArrayRef& operator=(std::nullopt_t /*unused*/) noexcept { + wrapped_opt_array_ref = std::nullopt; + return *this; + } + + OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; + + OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef& operator=( + const std::optional>& other) noexcept { + wrapped_opt_array_ref = other; + return *this; + } + + constexpr OptionalArrayRef& operator=( + std::optional>&& other) noexcept { + wrapped_opt_array_ref = std::move(other); + return *this; + } + + template < + typename U = ArrayRef, + typename = std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + std::is_constructible_v, U&&> && + std::is_assignable_v&, U&&>>> + constexpr OptionalArrayRef& operator=(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&> && + std::is_nothrow_assignable_v&, U&&>) { + wrapped_opt_array_ref = std::forward(value); + return *this; + } + + // Observers + + constexpr ArrayRef* operator->() noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef* operator->() const noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef& operator*() & noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef& operator*() const& noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& operator*() && noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& operator*() const&& noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr explicit operator bool() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr bool has_value() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr ArrayRef& value() & { + return wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef& value() const& { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& value() && { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& value() const&& { + return std::move(wrapped_opt_array_ref.value()); + } + + template + constexpr std:: + enable_if_t>, ArrayRef> + value_or(U&& default_value) const& { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + template + constexpr std:: + enable_if_t>, ArrayRef> + value_or(U&& default_value) && { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + // Modifiers + + constexpr void swap(OptionalArrayRef& other) noexcept { + std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); + } + + constexpr void reset() noexcept { + wrapped_opt_array_ref.reset(); + } + + template + constexpr std:: + enable_if_t, Args&&...>, ArrayRef&> + emplace(Args&&... args) noexcept( + std::is_nothrow_constructible_v, Args&&...>) { + return wrapped_opt_array_ref.emplace(std::forward(args)...); + } + + template + constexpr ArrayRef& emplace( + std::initializer_list il, + Args&&... args) noexcept { + return wrapped_opt_array_ref.emplace(il, std::forward(args)...); + } + + private: + std::optional> wrapped_opt_array_ref; +}; + +using OptionalIntArrayRef = OptionalArrayRef; + +inline bool operator==( + const OptionalIntArrayRef& a1, + const IntArrayRef& other) { + if (!a1.has_value()) { + return false; + } + return a1.value() == other; +} + +inline bool operator==( + const c10::IntArrayRef& a1, + const c10::OptionalIntArrayRef& a2) { + return a2 == a1; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ParallelGuard.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ParallelGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..e577497980fbf93d2e928b9c879f085cc1852a4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ParallelGuard.h @@ -0,0 +1,25 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { + +// RAII thread local guard that tracks whether code is being executed in +// `at::parallel_for` or `at::parallel_reduce` loop function. +class C10_API ParallelGuard { + public: + static bool is_enabled(); + + ParallelGuard(bool state); + ~ParallelGuard(); + + private: + bool previous_state_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Registry.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Registry.h new file mode 100644 index 0000000000000000000000000000000000000000..92d1809d8c3094d19c927d9594afab15eba475ad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Registry.h @@ -0,0 +1,334 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_REGISTRY_H_ +#define C10_UTIL_REGISTRY_H_ + +/** + * Simple registry implementation that uses static variables to + * register object creators during program initialization time. + */ + +// NB: This Registry works poorly when you have other namespaces. +// Make all macro invocations from inside the at namespace. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace c10 { + +template +inline std::string KeyStrRepr(const KeyType& /*key*/) { + return "[key type printing not supported]"; +} + +template <> +inline std::string KeyStrRepr(const std::string& key) { + return key; +} + +enum RegistryPriority { + REGISTRY_FALLBACK = 1, + REGISTRY_DEFAULT = 2, + REGISTRY_PREFERRED = 3, +}; + +/** + * @brief A template class that allows one to register classes by keys. + * + * The keys are usually a std::string specifying the name, but can be anything + * that can be used in a std::map. + * + * You should most likely not use the Registry class explicitly, but use the + * helper macros below to declare specific registries as well as registering + * objects. + */ +template +class Registry { + public: + typedef std::function Creator; + + Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {} + ~Registry() = default; + + void Register( + const SrcType& key, + Creator creator, + const RegistryPriority priority = REGISTRY_DEFAULT) { + std::lock_guard lock(register_mutex_); + // The if statement below is essentially the same as the following line: + // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key + // << " registered twice."; + // However, TORCH_CHECK_EQ depends on google logging, and since registration + // is carried out at static initialization time, we do not want to have an + // explicit dependency on glog's initialization function. + if (registry_.count(key) != 0) { + auto cur_priority = priority_[key]; + if (priority > cur_priority) { +#ifdef DEBUG + std::string warn_msg = + "Overwriting already registered item for key " + KeyStrRepr(key); + fprintf(stderr, "%s\n", warn_msg.c_str()); +#endif + registry_[key] = creator; + priority_[key] = priority; + } else if (priority == cur_priority) { + std::string err_msg = + "Key already registered with the same priority: " + KeyStrRepr(key); + fprintf(stderr, "%s\n", err_msg.c_str()); + if (terminate_) { + std::exit(1); + } else { + throw std::runtime_error(err_msg); + } + } else if (warning_) { + std::string warn_msg = + "Higher priority item already registered, skipping registration of " + + KeyStrRepr(key); + fprintf(stderr, "%s\n", warn_msg.c_str()); + } + } else { + registry_[key] = creator; + priority_[key] = priority; + } + } + + void Register( + const SrcType& key, + Creator creator, + const std::string& help_msg, + const RegistryPriority priority = REGISTRY_DEFAULT) { + Register(key, creator, priority); + help_message_[key] = help_msg; + } + + inline bool Has(const SrcType& key) { + return (registry_.count(key) != 0); + } + + ObjectPtrType Create(const SrcType& key, Args... args) { + auto it = registry_.find(key); + if (it == registry_.end()) { + // Returns nullptr if the key is not registered. + return nullptr; + } + return it->second(args...); + } + + /** + * Returns the keys currently registered as a std::vector. + */ + std::vector Keys() const { + std::vector keys; + keys.reserve(registry_.size()); + for (const auto& it : registry_) { + keys.push_back(it.first); + } + return keys; + } + + inline const std::unordered_map& HelpMessage() const { + return help_message_; + } + + const char* HelpMessage(const SrcType& key) const { + auto it = help_message_.find(key); + if (it == help_message_.end()) { + return nullptr; + } + return it->second.c_str(); + } + + // Used for testing, if terminate is unset, Registry throws instead of + // calling std::exit + void SetTerminate(bool terminate) { + terminate_ = terminate; + } + + C10_DISABLE_COPY_AND_ASSIGN(Registry); + Registry(Registry&&) = delete; + Registry& operator=(Registry&&) = delete; + + private: + std::unordered_map registry_; + std::unordered_map priority_; + bool terminate_{true}; + const bool warning_; + std::unordered_map help_message_; + std::mutex register_mutex_; +}; + +template +class Registerer { + public: + explicit Registerer( + const SrcType& key, + Registry* registry, + typename Registry::Creator creator, + const std::string& help_msg = "") { + registry->Register(key, creator, help_msg); + } + + explicit Registerer( + const SrcType& key, + const RegistryPriority priority, + Registry* registry, + typename Registry::Creator creator, + const std::string& help_msg = "") { + registry->Register(key, creator, help_msg, priority); + } + + template + static ObjectPtrType DefaultCreator(Args... args) { + return ObjectPtrType(new DerivedType(args...)); + } +}; + +/** + * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function + * declaration, as well as creating a convenient typename for its corresponding + * registerer. + */ +// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE +// as import and DEFINE as export, because these registry macros will be used +// in downstream shared libraries as well, and one cannot use *_API - the API +// macro will be defined on a per-shared-library basis. Semantically, when one +// declares a typed registry it is always going to be IMPORT, and when one +// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE), +// the instantiation unit is always going to be exported. +// +// The only unique condition is when in the same file one does DECLARE and +// DEFINE - in Windows compilers, this generates a warning that dllimport and +// dllexport are mixed, but the warning is fine and linker will be properly +// exporting the symbol. Same thing happens in the gflags flag declaration and +// definition caes. +#define C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_API ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName(); \ + typedef ::c10::Registerer, ##__VA_ARGS__> \ + Registerer##RegistryName + +#define TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + TORCH_API ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName(); \ + typedef ::c10::Registerer, ##__VA_ARGS__> \ + Registerer##RegistryName + +#define C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName() { \ + static ::c10::Registry, ##__VA_ARGS__>* \ + registry = new ::c10:: \ + Registry, ##__VA_ARGS__>(); \ + return registry; \ + } + +#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName() { \ + static ::c10::Registry, ##__VA_ARGS__>* \ + registry = \ + new ::c10::Registry, ##__VA_ARGS__>( \ + false); \ + return registry; \ + } + +// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated +// creator with comma in its templated arguments. +#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, RegistryName(), ##__VA_ARGS__); + +#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ + RegistryName, key, priority, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, priority, RegistryName(), ##__VA_ARGS__); + +#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, \ + RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ + ::c10::demangle_type<__VA_ARGS__>()); + +#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ + RegistryName, key, priority, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, \ + priority, \ + RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ + ::c10::demangle_type<__VA_ARGS__>()); + +// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use +// std::string as the key type, because that is the most commonly used cases. +#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string +// as the key +// type, because that is the most commonly used cases. +#define C10_REGISTER_CREATOR(RegistryName, key, ...) \ + C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__) + +#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \ + C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ + RegistryName, #key, priority, __VA_ARGS__) + +#define C10_REGISTER_CLASS(RegistryName, key, ...) \ + C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) + +#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \ + C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ + RegistryName, #key, priority, __VA_ARGS__) + +} // namespace c10 + +#endif // C10_UTIL_REGISTRY_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ScopeExit.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ScopeExit.h new file mode 100644 index 0000000000000000000000000000000000000000..fa4eaaceadd2588bbe53fcd51d3cbffde5d3b220 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ScopeExit.h @@ -0,0 +1,55 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +/** + * Mostly copied from https://llvm.org/doxygen/ScopeExit_8h_source.html + */ +template +class scope_exit { + Callable ExitFunction; + bool Engaged = true; // False once moved-from or release()d. + + public: + template + // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) + explicit scope_exit(Fp&& F) : ExitFunction(std::forward(F)) {} + + scope_exit(scope_exit&& Rhs) noexcept + : ExitFunction(std::move(Rhs.ExitFunction)), Engaged(Rhs.Engaged) { + Rhs.release(); + } + scope_exit(const scope_exit&) = delete; + scope_exit& operator=(scope_exit&&) = delete; + scope_exit& operator=(const scope_exit&) = delete; + + void release() { + Engaged = false; + } + + ~scope_exit() { + if (Engaged) { + ExitFunction(); + } + } +}; + +// Keeps the callable object that is passed in, and execute it at the +// destruction of the returned object (usually at the scope exit where the +// returned object is kept). +// +// Interface is specified by p0052r2. +template +scope_exit> make_scope_exit(Callable&& F) { + return scope_exit>(std::forward(F)); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Semaphore.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Semaphore.h new file mode 100644 index 0000000000000000000000000000000000000000..1a0e63680bee7b6aa9107d6f0a20fa50388d6acb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Semaphore.h @@ -0,0 +1,76 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +/* + a simple semaphore interface. +*/ + +// note: __cpp_lib_semaphore will not be defined in some apple platforms +// even if >= C++20. +#if __has_include() && defined(__cpp_lib_semaphore) && __cpp_lib_semaphore >= 201907L +#define C10_SEMAPHORE_USE_STL +#endif + +#ifdef C10_SEMAPHORE_USE_STL +#include +#else +// To use moodycamel semaphore, we need to include the header file +// for concurrentqueue first. Hiding implementation detail here. +#ifdef BLOCK_SIZE +#pragma push_macro("BLOCK_SIZE") +#undef BLOCK_SIZE +#include // @manual +#pragma pop_macro("BLOCK_SIZE") +#else +#include // @manual +#endif + +#include // @manual +#endif + +namespace c10 { + +class Semaphore { + public: + Semaphore(int32_t initial_count = 0) : impl_(initial_count) {} + + void release(int32_t n = 1) { +#ifdef C10_SEMAPHORE_USE_STL + impl_.release(n); +#else + impl_.signal(n); +#endif + } + + void acquire() { +#ifdef C10_SEMAPHORE_USE_STL + impl_.acquire(); +#else + impl_.wait(); +#endif + } + + bool tryAcquire() { +#ifdef C10_SEMAPHORE_USE_STL + return impl_.try_acquire(); +#else + return impl_.tryWait(); +#endif + } + + private: +#ifdef C10_SEMAPHORE_USE_STL + std::counting_semaphore<> impl_; +#else + moodycamel::LightweightSemaphore impl_; +#endif +}; +} // namespace c10 + +#undef C10_SEMAPHORE_USE_STL + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/SmallBuffer.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/SmallBuffer.h new file mode 100644 index 0000000000000000000000000000000000000000..1c40d21a692f0470d02d25bc8794f1b8d58c55a0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/SmallBuffer.h @@ -0,0 +1,92 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include + +/** Helper class for allocating temporary fixed size arrays with SBO. + * + * This is intentionally much simpler than SmallVector, to improve performance + * at the expense of many features: + * - No zero-initialization for numeric types + * - No resizing after construction + * - No copy/move + * - No non-trivial types + */ + +namespace c10 { + +template +class SmallBuffer { + static_assert(std::is_trivial_v, "SmallBuffer is intended for POD types"); + + std::array storage_; + size_t size_{}; + T* data_{}; + + public: + SmallBuffer(size_t size) : size_(size) { + if (size > N) { + data_ = new T[size]; + } else { + data_ = &storage_[0]; + } + } + + SmallBuffer(const SmallBuffer&) = delete; + SmallBuffer& operator=(const SmallBuffer&) = delete; + + // move constructor is needed in function return + SmallBuffer(SmallBuffer&& rhs) noexcept : size_{rhs.size_} { + rhs.size_ = 0; + if (size_ > N) { + data_ = rhs.data_; + rhs.data_ = nullptr; + } else { + storage_ = std::move(rhs.storage_); + data_ = &storage_[0]; + } + } + + SmallBuffer& operator=(SmallBuffer&&) = delete; + + ~SmallBuffer() { + if (size_ > N) { + delete[] data_; + } + } + T& operator[](size_t idx) { + return data()[idx]; + } + const T& operator[](size_t idx) const { + return data()[idx]; + } + T* data() { + return data_; + } + const T* data() const { + return data_; + } + size_t size() const { + return size_; + } + T* begin() { + return data_; + } + const T* begin() const { + return data_; + } + T* end() { + return data_ + size_; + } + const T* end() const { + return data_ + size_; + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/SmallVector.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/SmallVector.h new file mode 100644 index 0000000000000000000000000000000000000000..b2a4dbb0f92f530cd21dc8a63ee48f82f430393d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/SmallVector.h @@ -0,0 +1,1472 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the SmallVector class. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::SmallVector. +// used std::is_trivially_{copy,move}_constructible +// replaced iterator_range constructor with inline Container&& constructor +// replaced LLVM_NODISCARD, LLVM_LIKELY, and LLVM_UNLIKELY with c10 equivalents +// removed LLVM_GSL_OWNER +// added SmallVector::at +// added operator<< for std::ostream +// added C10_API to export SmallVectorBase + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// This is all the stuff common to all SmallVectors. +/// +/// The template parameter specifies the type which should be used to hold the +/// Size and Capacity of the SmallVector, so it can be adjusted. +/// Using 32 bit size is desirable to shrink the size of the SmallVector. +/// Using 64 bit size is desirable for cases like SmallVector, where a +/// 32 bit size would limit the vector to ~4GB. SmallVectors are used for +/// buffering bitcode output - which can exceed 4GB. +template +class C10_API SmallVectorBase { + protected: + void* BeginX; + Size_T Size = 0, Capacity; + + /// The maximum value of the Size_T used. + static constexpr size_t SizeTypeMax() { + return std::numeric_limits::max(); + } + + SmallVectorBase(void* FirstEl, size_t TotalCapacity) + : BeginX(FirstEl), Capacity(TotalCapacity) {} + + /// This is a helper for \a grow() that's out of line to reduce code + /// duplication. This function will report a fatal error if it can't grow at + /// least to \p MinSize. + void* mallocForGrow(size_t MinSize, size_t TSize, size_t& NewCapacity); + + /// This is an implementation of the grow() method which only works + /// on POD-like data types and is out of line to reduce code duplication. + /// This function will report a fatal error if it cannot increase capacity. + void grow_pod(const void* FirstEl, size_t MinSize, size_t TSize); + + public: + SmallVectorBase() = delete; + size_t size() const { + return Size; + } + size_t capacity() const { + return Capacity; + } + + [[nodiscard]] bool empty() const { + return !Size; + } + + /// Set the array size to \p N, which the current array must have enough + /// capacity for. + /// + /// This does not construct or destroy any elements in the vector. + /// + /// Clients can use this in conjunction with capacity() to write past the end + /// of the buffer when they know that more elements are available, and only + /// update the size later. This avoids the cost of value initializing elements + /// which will only be overwritten. + void set_size(size_t N) { + assert(N <= capacity()); + Size = N; + } +}; + +template +using SmallVectorSizeType = + std::conditional_t= 8, uint64_t, uint32_t>; + +/// Figure out the offset of the first element. +template +struct SmallVectorAlignmentAndSize { + // NOLINTNEXTLINE(*c-arrays*) + alignas(SmallVectorBase>) char Base[sizeof( + SmallVectorBase>)]; + // NOLINTNEXTLINE(*c-arrays*) + alignas(T) char FirstEl[sizeof(T)]; +}; + +/// This is the part of SmallVectorTemplateBase which does not depend on whether +/// the type T is a POD. The extra dummy template argument is used by ArrayRef +/// to avoid unnecessarily requiring T to be complete. +template +class SmallVectorTemplateCommon + : public SmallVectorBase> { + using Base = SmallVectorBase>; + + /// Find the address of the first element. For this pointer math to be valid + /// with small-size of 0 for T with lots of alignment, it's important that + /// SmallVectorStorage is properly-aligned even for small-size of 0. + void* getFirstEl() const { + return const_cast(reinterpret_cast( + reinterpret_cast(this) + + offsetof(SmallVectorAlignmentAndSize, FirstEl))); + } + // Space after 'FirstEl' is clobbered, do not add any instance vars after it. + + protected: + SmallVectorTemplateCommon(size_t Size) : Base(getFirstEl(), Size) {} + + void grow_pod(size_t MinSize, size_t TSize) { + Base::grow_pod(getFirstEl(), MinSize, TSize); + } + + /// Return true if this is a smallvector which has not had dynamic + /// memory allocated for it. + bool isSmall() const { + return this->BeginX == getFirstEl(); + } + + /// Put this vector in a state of being small. + void resetToSmall() { + this->BeginX = getFirstEl(); + this->Size = this->Capacity = 0; // FIXME: Setting Capacity to 0 is suspect. + } + + /// Return true if V is an internal reference to the given range. + bool isReferenceToRange(const void* V, const void* First, const void* Last) + const { + // Use std::less to avoid UB. + std::less<> LessThan; + return !LessThan(V, First) && LessThan(V, Last); + } + + /// Return true if V is an internal reference to this vector. + bool isReferenceToStorage(const void* V) const { + return isReferenceToRange(V, this->begin(), this->end()); + } + + /// Return true if First and Last form a valid (possibly empty) range in this + /// vector's storage. + bool isRangeInStorage(const void* First, const void* Last) const { + // Use std::less to avoid UB. + std::less<> LessThan; + return !LessThan(First, this->begin()) && !LessThan(Last, First) && + !LessThan(this->end(), Last); + } + + /// Return true unless Elt will be invalidated by resizing the vector to + /// NewSize. + bool isSafeToReferenceAfterResize(const void* Elt, size_t NewSize) { + // Past the end. + if (C10_LIKELY(!isReferenceToStorage(Elt))) + return true; + + // Return false if Elt will be destroyed by shrinking. + if (NewSize <= this->size()) + return Elt < this->begin() + NewSize; + + // Return false if we need to grow. + return NewSize <= this->capacity(); + } + + /// Check whether Elt will be invalidated by resizing the vector to NewSize. + void assertSafeToReferenceAfterResize(const void* Elt, size_t NewSize) { + (void)Elt; // Suppress unused variable warning + (void)NewSize; // Suppress unused variable warning + assert( + isSafeToReferenceAfterResize(Elt, NewSize) && + "Attempting to reference an element of the vector in an operation " + "that invalidates it"); + } + + /// Check whether Elt will be invalidated by increasing the size of the + /// vector by N. + void assertSafeToAdd(const void* Elt, size_t N = 1) { + this->assertSafeToReferenceAfterResize(Elt, this->size() + N); + } + + /// Check whether any part of the range will be invalidated by clearing. + void assertSafeToReferenceAfterClear(const T* From, const T* To) { + if (From == To) + return; + this->assertSafeToReferenceAfterResize(From, 0); + this->assertSafeToReferenceAfterResize(To - 1, 0); + } + template < + class ItTy, + std::enable_if_t, T*>, bool> = + false> + void assertSafeToReferenceAfterClear(ItTy /*unused*/, ItTy /*unused*/) {} + + /// Check whether any part of the range will be invalidated by growing. + void assertSafeToAddRange(const T* From, const T* To) { + if (From == To) + return; + this->assertSafeToAdd(From, To - From); + this->assertSafeToAdd(To - 1, To - From); + } + template < + class ItTy, + std::enable_if_t, T*>, bool> = + false> + void assertSafeToAddRange(ItTy /*unused*/, ItTy /*unused*/) {} + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + template + static const T* reserveForParamAndGetAddressImpl( + U* This, + const T& Elt, + size_t N) { + size_t NewSize = This->size() + N; + if (C10_LIKELY(NewSize <= This->capacity())) + return &Elt; + + bool ReferencesStorage = false; + int64_t Index = -1; + if constexpr (!U::TakesParamByValue) { + if (C10_UNLIKELY(This->isReferenceToStorage(&Elt))) { + ReferencesStorage = true; + Index = &Elt - This->begin(); + } + } + This->grow(NewSize); + return ReferencesStorage ? This->begin() + Index : &Elt; + } + + public: + using size_type = size_t; + using difference_type = ptrdiff_t; + using value_type = T; + using iterator = T*; + using const_iterator = const T*; + + using const_reverse_iterator = std::reverse_iterator; + using reverse_iterator = std::reverse_iterator; + + using reference = T&; + using const_reference = const T&; + using pointer = T*; + using const_pointer = const T*; + + using Base::capacity; + using Base::empty; + using Base::size; + + // forward iterator creation methods. + iterator begin() { + return (iterator)this->BeginX; + } + const_iterator begin() const { + return (const_iterator)this->BeginX; + } + iterator end() { + return begin() + size(); + } + const_iterator end() const { + return begin() + size(); + } + + // reverse iterator creation methods. + reverse_iterator rbegin() { + return reverse_iterator(end()); + } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + reverse_iterator rend() { + return reverse_iterator(begin()); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + + size_type size_in_bytes() const { + return size() * sizeof(T); + } + constexpr size_type max_size() const { + return std::min(this->SizeTypeMax(), size_type(-1) / sizeof(T)); + } + + size_t capacity_in_bytes() const { + return capacity() * sizeof(T); + } + + /// Return a pointer to the vector's buffer, even if empty(). + pointer data() { + return pointer(begin()); + } + /// Return a pointer to the vector's buffer, even if empty(). + const_pointer data() const { + return const_pointer(begin()); + } + + // SmallVector::at is NOT from LLVM. + reference at(size_type idx) { + assert(idx < size()); + return begin()[idx]; + } + const_reference at(size_type idx) const { + assert(idx < size()); + return begin()[idx]; + } + reference operator[](size_type idx) { + assert(idx < size()); + return begin()[idx]; + } + const_reference operator[](size_type idx) const { + assert(idx < size()); + return begin()[idx]; + } + + reference front() { + assert(!empty()); + return begin()[0]; + } + const_reference front() const { + assert(!empty()); + return begin()[0]; + } + + reference back() { + assert(!empty()); + return end()[-1]; + } + const_reference back() const { + assert(!empty()); + return end()[-1]; + } +}; + +/// SmallVectorTemplateBase - This is where we put +/// method implementations that are designed to work with non-trivial T's. +/// +/// We approximate is_trivially_copyable with trivial move/copy construction and +/// trivial destruction. While the standard doesn't specify that you're allowed +/// copy these types with memcpy, there is no way for the type to observe this. +/// This catches the important case of std::pair, which is not +/// trivially assignable. +/// +/// XXX: if build fails here fall back to C10_IS_TRIVIALLY_COPYABLE and make a +/// note +template < + typename T, + bool = (std::is_trivially_copy_constructible_v) && + (std::is_trivially_move_constructible_v) && + std::is_trivially_destructible_v> +class SmallVectorTemplateBase : public SmallVectorTemplateCommon { + friend class SmallVectorTemplateCommon; + + protected: + static constexpr bool TakesParamByValue = false; + using ValueParamT = const T&; + + SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon(Size) {} + + static void destroy_range(T* S, T* E) { + while (S != E) { + --E; + E->~T(); + } + } + + /// Move the range [I, E) into the uninitialized memory starting with "Dest", + /// constructing elements as needed. + template + static void uninitialized_move(It1 I, It1 E, It2 Dest) { + std::uninitialized_copy( + std::make_move_iterator(I), std::make_move_iterator(E), Dest); + } + + /// Copy the range [I, E) onto the uninitialized memory starting with "Dest", + /// constructing elements as needed. + template + static void uninitialized_copy(It1 I, It1 E, It2 Dest) { + std::uninitialized_copy(I, E, Dest); + } + + /// Grow the allocated memory (without initializing new elements), doubling + /// the size of the allocated memory. Guarantees space for at least one more + /// element, or MinSize more elements if specified. + void grow(size_t MinSize = 0); + + /// Create a new allocation big enough for \p MinSize and pass back its size + /// in \p NewCapacity. This is the first section of \a grow(). + T* mallocForGrow(size_t MinSize, size_t& NewCapacity) { + return static_cast( + SmallVectorBase>::mallocForGrow( + MinSize, sizeof(T), NewCapacity)); + } + + /// Move existing elements over to the new allocation \p NewElts, the middle + /// section of \a grow(). + void moveElementsForGrow(T* NewElts); + + /// Transfer ownership of the allocation, finishing up \a grow(). + void takeAllocationForGrow(T* NewElts, size_t NewCapacity); + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + const T* reserveForParamAndGetAddress(const T& Elt, size_t N = 1) { + return this->reserveForParamAndGetAddressImpl(this, Elt, N); + } + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + T* reserveForParamAndGetAddress(T& Elt, size_t N = 1) { + return const_cast(this->reserveForParamAndGetAddressImpl(this, Elt, N)); + } + + static T&& forward_value_param(T&& V) { + return std::move(V); + } + static const T& forward_value_param(const T& V) { + return V; + } + + void growAndAssign(size_t NumElts, const T& Elt) { + // Grow manually in case Elt is an internal reference. + size_t NewCapacity = 0; + T* NewElts = mallocForGrow(NumElts, NewCapacity); + std::uninitialized_fill_n(NewElts, NumElts, Elt); + this->destroy_range(this->begin(), this->end()); + takeAllocationForGrow(NewElts, NewCapacity); + this->set_size(NumElts); + } + + template + T& growAndEmplaceBack(ArgTypes&&... Args) { + // Grow manually in case one of Args is an internal reference. + size_t NewCapacity = 0; + T* NewElts = mallocForGrow(0, NewCapacity); + ::new ((void*)(NewElts + this->size())) T(std::forward(Args)...); + moveElementsForGrow(NewElts); + takeAllocationForGrow(NewElts, NewCapacity); + this->set_size(this->size() + 1); + return this->back(); + } + + public: + void push_back(const T& Elt) { + const T* EltPtr = reserveForParamAndGetAddress(Elt); + ::new ((void*)this->end()) T(*EltPtr); + this->set_size(this->size() + 1); + } + + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + void push_back(T&& Elt) { + T* EltPtr = reserveForParamAndGetAddress(Elt); + ::new ((void*)this->end()) T(::std::move(*EltPtr)); + this->set_size(this->size() + 1); + } + + void pop_back() { + this->set_size(this->size() - 1); + this->end()->~T(); + } +}; + +// Define this out-of-line to dissuade the C++ compiler from inlining it. +template +void SmallVectorTemplateBase::grow(size_t MinSize) { + size_t NewCapacity = 0; + T* NewElts = mallocForGrow(MinSize, NewCapacity); + moveElementsForGrow(NewElts); + takeAllocationForGrow(NewElts, NewCapacity); +} + +// Define this out-of-line to dissuade the C++ compiler from inlining it. +template +void SmallVectorTemplateBase::moveElementsForGrow( + T* NewElts) { + // Move the elements over. + this->uninitialized_move(this->begin(), this->end(), NewElts); + + // Destroy the original elements. + destroy_range(this->begin(), this->end()); +} + +// Define this out-of-line to dissuade the C++ compiler from inlining it. +template +void SmallVectorTemplateBase::takeAllocationForGrow( + T* NewElts, + size_t NewCapacity) { + // If this wasn't grown from the inline copy, deallocate the old space. + if (!this->isSmall()) + free(this->begin()); + + this->BeginX = NewElts; + this->Capacity = NewCapacity; +} + +/// SmallVectorTemplateBase - This is where we put +/// method implementations that are designed to work with trivially copyable +/// T's. This allows using memcpy in place of copy/move construction and +/// skipping destruction. +template +class SmallVectorTemplateBase : public SmallVectorTemplateCommon { + friend class SmallVectorTemplateCommon; + + protected: + /// True if it's cheap enough to take parameters by value. Doing so avoids + /// overhead related to mitigations for reference invalidation. + static constexpr bool TakesParamByValue = sizeof(T) <= 2 * sizeof(void*); + + /// Either const T& or T, depending on whether it's cheap enough to take + /// parameters by value. + using ValueParamT = std::conditional_t; + + SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon(Size) {} + + // No need to do a destroy loop for POD's. + static void destroy_range(T* /*unused*/, T* /*unused*/) {} + + /// Move the range [I, E) onto the uninitialized memory + /// starting with "Dest", constructing elements into it as needed. + template + static void uninitialized_move(It1 I, It1 E, It2 Dest) { + // Just do a copy. + uninitialized_copy(I, E, Dest); + } + + /// Copy the range [I, E) onto the uninitialized memory + /// starting with "Dest", constructing elements into it as needed. + template + static void uninitialized_copy(It1 I, It1 E, It2 Dest) { + // Arbitrary iterator types; just use the basic implementation. + std::uninitialized_copy(I, E, Dest); + } + + /// Copy the range [I, E) onto the uninitialized memory + /// starting with "Dest", constructing elements into it as needed. + template + static void uninitialized_copy( + T1* I, + T1* E, + T2* Dest, + std::enable_if_t, T2>>* /*unused*/ + = nullptr) { + // Use memcpy for PODs iterated by pointers (which includes SmallVector + // iterators): std::uninitialized_copy optimizes to memmove, but we can + // use memcpy here. Note that I and E are iterators and thus might be + // invalid for memcpy if they are equal. + if (I != E) + memcpy(reinterpret_cast(Dest), I, (E - I) * sizeof(T)); + } + + /// Double the size of the allocated memory, guaranteeing space for at + /// least one more element or MinSize if specified. + void grow(size_t MinSize = 0) { + this->grow_pod(MinSize, sizeof(T)); + } + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + const T* reserveForParamAndGetAddress(const T& Elt, size_t N = 1) { + return this->reserveForParamAndGetAddressImpl(this, Elt, N); + } + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + T* reserveForParamAndGetAddress(T& Elt, size_t N = 1) { + return const_cast(this->reserveForParamAndGetAddressImpl(this, Elt, N)); + } + + /// Copy \p V or return a reference, depending on \a ValueParamT. + static ValueParamT forward_value_param(ValueParamT V) { + return V; + } + + void growAndAssign(size_t NumElts, T Elt) { + // Elt has been copied in case it's an internal reference, side-stepping + // reference invalidation problems without losing the realloc optimization. + this->set_size(0); + this->grow(NumElts); + std::uninitialized_fill_n(this->begin(), NumElts, Elt); + this->set_size(NumElts); + } + + template + T& growAndEmplaceBack(ArgTypes&&... Args) { + // Use push_back with a copy in case Args has an internal reference, + // side-stepping reference invalidation problems without losing the realloc + // optimization. + push_back(T(std::forward(Args)...)); + return this->back(); + } + + public: + void push_back(ValueParamT Elt) { + const T* EltPtr = reserveForParamAndGetAddress(Elt); + memcpy(reinterpret_cast(this->end()), EltPtr, sizeof(T)); + this->set_size(this->size() + 1); + } + + void pop_back() { + this->set_size(this->size() - 1); + } +}; + +/// This class consists of common code factored out of the SmallVector class to +/// reduce code duplication based on the SmallVector 'N' template parameter. +template +class SmallVectorImpl : public SmallVectorTemplateBase { + using SuperClass = SmallVectorTemplateBase; + + public: + using iterator = typename SuperClass::iterator; + using const_iterator = typename SuperClass::const_iterator; + using reference = typename SuperClass::reference; + using size_type = typename SuperClass::size_type; + + protected: + using SmallVectorTemplateBase::TakesParamByValue; + using ValueParamT = typename SuperClass::ValueParamT; + + // Default ctor - Initialize to empty. + explicit SmallVectorImpl(unsigned N) : SmallVectorTemplateBase(N) {} + + public: + SmallVectorImpl(const SmallVectorImpl&) = delete; + + ~SmallVectorImpl() { + // Subclass has already destructed this vector's elements. + // If this wasn't grown from the inline copy, deallocate the old space. + if (!this->isSmall()) + free(this->begin()); + } + + void clear() { + this->destroy_range(this->begin(), this->end()); + this->Size = 0; + } + + private: + template + void resizeImpl(size_type N) { + if (N < this->size()) { + this->pop_back_n(this->size() - N); + } else if (N > this->size()) { + this->reserve(N); + for (auto I = this->end(), E = this->begin() + N; I != E; ++I) + if (ForOverwrite) + new (&*I) T; + else + new (&*I) T(); + this->set_size(N); + } + } + + public: + void resize(size_type N) { + resizeImpl(N); + } + + /// Like resize, but \ref T is POD, the new values won't be initialized. + void resize_for_overwrite(size_type N) { + resizeImpl(N); + } + + void resize(size_type N, ValueParamT NV) { + if (N == this->size()) + return; + + if (N < this->size()) { + this->pop_back_n(this->size() - N); + return; + } + + // N > this->size(). Defer to append. + this->append(N - this->size(), NV); + } + + void reserve(size_type N) { + if (this->capacity() < N) + this->grow(N); + } + + void pop_back_n(size_type NumItems) { + assert(this->size() >= NumItems); + this->destroy_range(this->end() - NumItems, this->end()); + this->set_size(this->size() - NumItems); + } + + [[nodiscard]] T pop_back_val() { + T Result = ::std::move(this->back()); + this->pop_back(); + return Result; + } + + void swap(SmallVectorImpl& RHS) noexcept; + + /// Add the specified range to the end of the SmallVector. + template < + typename in_iter, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + void append(in_iter in_start, in_iter in_end) { + this->assertSafeToAddRange(in_start, in_end); + size_type NumInputs = std::distance(in_start, in_end); + this->reserve(this->size() + NumInputs); + this->uninitialized_copy(in_start, in_end, this->end()); + this->set_size(this->size() + NumInputs); + } + + /// Append \p NumInputs copies of \p Elt to the end. + void append(size_type NumInputs, ValueParamT Elt) { + const T* EltPtr = this->reserveForParamAndGetAddress(Elt, NumInputs); + std::uninitialized_fill_n(this->end(), NumInputs, *EltPtr); + this->set_size(this->size() + NumInputs); + } + + void append(std::initializer_list IL) { + append(IL.begin(), IL.end()); + } + + void append(const SmallVectorImpl& RHS) { + append(RHS.begin(), RHS.end()); + } + + void assign(size_type NumElts, ValueParamT Elt) { + // Note that Elt could be an internal reference. + if (NumElts > this->capacity()) { + this->growAndAssign(NumElts, Elt); + return; + } + + // Assign over existing elements. + std::fill_n(this->begin(), std::min(NumElts, this->size()), Elt); + if (NumElts > this->size()) + std::uninitialized_fill_n(this->end(), NumElts - this->size(), Elt); + else if (NumElts < this->size()) + this->destroy_range(this->begin() + NumElts, this->end()); + this->set_size(NumElts); + } + + // FIXME: Consider assigning over existing elements, rather than clearing & + // re-initializing them - for all assign(...) variants. + + template < + typename in_iter, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + void assign(in_iter in_start, in_iter in_end) { + this->assertSafeToReferenceAfterClear(in_start, in_end); + clear(); + append(in_start, in_end); + } + + void assign(std::initializer_list IL) { + clear(); + append(IL); + } + + void assign(const SmallVectorImpl& RHS) { + assign(RHS.begin(), RHS.end()); + } + + iterator erase(iterator I) { + assert( + this->isReferenceToStorage(I) && "Iterator to erase is out of bounds."); + + iterator N = I; + // Shift all elts down one. + std::move(I + 1, this->end(), I); + // Drop the last elt. + this->pop_back(); + return N; + } + + iterator erase(iterator S, iterator E) { + assert(this->isRangeInStorage(S, E) && "Range to erase is out of bounds."); + + iterator N = S; + // Shift all elts down. + iterator I = std::move(E, this->end(), S); + // Drop the last elts. + this->destroy_range(I, this->end()); + this->set_size(I - this->begin()); + return N; + } + + private: + template + iterator insert_one_impl(iterator I, ArgType&& Elt) { + // Callers ensure that ArgType is derived from T. + static_assert( + std::is_same>, T>:: + value, + "ArgType must be derived from T!"); + + if (I == this->end()) { // Important special case for empty vector. + this->push_back(::std::forward(Elt)); + return this->end() - 1; + } + + assert( + this->isReferenceToStorage(I) && + "Insertion iterator is out of bounds."); + + // Grow if necessary. + size_t Index = I - this->begin(); + std::remove_reference_t* EltPtr = + this->reserveForParamAndGetAddress(Elt); + I = this->begin() + Index; + + ::new ((void*)this->end()) T(::std::move(this->back())); + // Push everything else over. + std::move_backward(I, this->end() - 1, this->end()); + this->set_size(this->size() + 1); + + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + static_assert( + !TakesParamByValue || std::is_same_v, + "ArgType must be 'T' when taking by value!"); + if (!TakesParamByValue && this->isReferenceToRange(EltPtr, I, this->end())) + ++EltPtr; + + *I = ::std::forward(*EltPtr); + return I; + } + + public: + iterator insert(iterator I, T&& Elt) { + return insert_one_impl(I, this->forward_value_param(std::move(Elt))); + } + + iterator insert(iterator I, const T& Elt) { + return insert_one_impl(I, this->forward_value_param(Elt)); + } + + iterator insert(iterator I, size_type NumToInsert, ValueParamT Elt) { + // Convert iterator to elt# to avoid invalidating iterator when we reserve() + size_t InsertElt = I - this->begin(); + + if (I == this->end()) { // Important special case for empty vector. + append(NumToInsert, Elt); + return this->begin() + InsertElt; + } + + assert( + this->isReferenceToStorage(I) && + "Insertion iterator is out of bounds."); + + // Ensure there is enough space, and get the (maybe updated) address of + // Elt. + const T* EltPtr = this->reserveForParamAndGetAddress(Elt, NumToInsert); + + // Uninvalidate the iterator. + I = this->begin() + InsertElt; + + // If there are more elements between the insertion point and the end of the + // range than there are being inserted, we can use a simple approach to + // insertion. Since we already reserved space, we know that this won't + // reallocate the vector. + if (size_t(this->end() - I) >= NumToInsert) { + T* OldEnd = this->end(); + append( + std::move_iterator(this->end() - NumToInsert), + std::move_iterator(this->end())); + + // Copy the existing elements that get replaced. + std::move_backward(I, OldEnd - NumToInsert, OldEnd); + + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + + std::fill_n(I, NumToInsert, *EltPtr); + return I; + } + + // Otherwise, we're inserting more elements than exist already, and we're + // not inserting at the end. + + // Move over the elements that we're about to overwrite. + T* OldEnd = this->end(); + this->set_size(this->size() + NumToInsert); + size_t NumOverwritten = OldEnd - I; + this->uninitialized_move(I, OldEnd, this->end() - NumOverwritten); + + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + + // Replace the overwritten part. + std::fill_n(I, NumOverwritten, *EltPtr); + + // Insert the non-overwritten middle part. + std::uninitialized_fill_n(OldEnd, NumToInsert - NumOverwritten, *EltPtr); + return I; + } + + template < + typename ItTy, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + iterator insert(iterator I, ItTy From, ItTy To) { + // Convert iterator to elt# to avoid invalidating iterator when we reserve() + size_t InsertElt = I - this->begin(); + + if (I == this->end()) { // Important special case for empty vector. + append(From, To); + return this->begin() + InsertElt; + } + + assert( + this->isReferenceToStorage(I) && + "Insertion iterator is out of bounds."); + + // Check that the reserve that follows doesn't invalidate the iterators. + this->assertSafeToAddRange(From, To); + + size_t NumToInsert = std::distance(From, To); + + // Ensure there is enough space. + reserve(this->size() + NumToInsert); + + // Uninvalidate the iterator. + I = this->begin() + InsertElt; + + // If there are more elements between the insertion point and the end of the + // range than there are being inserted, we can use a simple approach to + // insertion. Since we already reserved space, we know that this won't + // reallocate the vector. + if (size_t(this->end() - I) >= NumToInsert) { + T* OldEnd = this->end(); + append( + std::move_iterator(this->end() - NumToInsert), + std::move_iterator(this->end())); + + // Copy the existing elements that get replaced. + std::move_backward(I, OldEnd - NumToInsert, OldEnd); + + std::copy(From, To, I); + return I; + } + + // Otherwise, we're inserting more elements than exist already, and we're + // not inserting at the end. + + // Move over the elements that we're about to overwrite. + T* OldEnd = this->end(); + this->set_size(this->size() + NumToInsert); + size_t NumOverwritten = OldEnd - I; + this->uninitialized_move(I, OldEnd, this->end() - NumOverwritten); + + // Replace the overwritten part. + for (T* J = I; NumOverwritten > 0; --NumOverwritten) { + *J = *From; + ++J; + ++From; + } + + // Insert the non-overwritten middle part. + this->uninitialized_copy(From, To, OldEnd); + return I; + } + + void insert(iterator I, std::initializer_list IL) { + insert(I, IL.begin(), IL.end()); + } + + template + reference emplace_back(ArgTypes&&... Args) { + if (C10_UNLIKELY(this->size() >= this->capacity())) + return this->growAndEmplaceBack(std::forward(Args)...); + + ::new ((void*)this->end()) T(std::forward(Args)...); + this->set_size(this->size() + 1); + return this->back(); + } + + SmallVectorImpl& operator=(const SmallVectorImpl& RHS); + + SmallVectorImpl& operator=(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_destructible_v); + + bool operator==(const SmallVectorImpl& RHS) const { + if (this->size() != RHS.size()) + return false; + return std::equal(this->begin(), this->end(), RHS.begin()); + } + bool operator!=(const SmallVectorImpl& RHS) const { + return !(*this == RHS); + } + + bool operator<(const SmallVectorImpl& RHS) const { + return std::lexicographical_compare( + this->begin(), this->end(), RHS.begin(), RHS.end()); + } +}; + +template +void SmallVectorImpl::swap(SmallVectorImpl& RHS) noexcept { + if (this == &RHS) + return; + + // We can only avoid copying elements if neither vector is small. + if (!this->isSmall() && !RHS.isSmall()) { + std::swap(this->BeginX, RHS.BeginX); + std::swap(this->Size, RHS.Size); + std::swap(this->Capacity, RHS.Capacity); + return; + } + this->reserve(RHS.size()); + RHS.reserve(this->size()); + + // Swap the shared elements. + size_t NumShared = this->size(); + if (NumShared > RHS.size()) + NumShared = RHS.size(); + for (size_type i = 0; i != NumShared; ++i) + std::swap((*this)[i], RHS[i]); + + // Copy over the extra elts. + if (this->size() > RHS.size()) { + size_t EltDiff = this->size() - RHS.size(); + this->uninitialized_copy(this->begin() + NumShared, this->end(), RHS.end()); + RHS.set_size(RHS.size() + EltDiff); + this->destroy_range(this->begin() + NumShared, this->end()); + this->set_size(NumShared); + } else if (RHS.size() > this->size()) { + size_t EltDiff = RHS.size() - this->size(); + this->uninitialized_copy(RHS.begin() + NumShared, RHS.end(), this->end()); + this->set_size(this->size() + EltDiff); + this->destroy_range(RHS.begin() + NumShared, RHS.end()); + RHS.set_size(NumShared); + } +} + +template +SmallVectorImpl& SmallVectorImpl::operator=( + const SmallVectorImpl& RHS) { + // Avoid self-assignment. + if (this == &RHS) + return *this; + + // If we already have sufficient space, assign the common elements, then + // destroy any excess. + size_t RHSSize = RHS.size(); + size_t CurSize = this->size(); + if (CurSize >= RHSSize) { + // Assign common elements. + iterator NewEnd; + if (RHSSize) + NewEnd = std::copy(RHS.begin(), RHS.begin() + RHSSize, this->begin()); + else + NewEnd = this->begin(); + + // Destroy excess elements. + this->destroy_range(NewEnd, this->end()); + + // Trim. + this->set_size(RHSSize); + return *this; + } + + // If we have to grow to have enough elements, destroy the current elements. + // This allows us to avoid copying them during the grow. + // FIXME: don't do this if they're efficiently moveable. + if (this->capacity() < RHSSize) { + // Destroy current elements. + this->clear(); + CurSize = 0; + this->grow(RHSSize); + } else if (CurSize) { + // Otherwise, use assignment for the already-constructed elements. + std::copy(RHS.begin(), RHS.begin() + CurSize, this->begin()); + } + + // Copy construct the new elements in place. + this->uninitialized_copy( + RHS.begin() + CurSize, RHS.end(), this->begin() + CurSize); + + // Set end. + this->set_size(RHSSize); + return *this; +} + +template +SmallVectorImpl& SmallVectorImpl:: +operator=(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_destructible_v) { + // Avoid self-assignment. + if (this == &RHS) + return *this; + + // If the RHS isn't small, clear this vector and then steal its buffer. + if (!RHS.isSmall()) { + this->destroy_range(this->begin(), this->end()); + if (!this->isSmall()) + free(this->begin()); + this->BeginX = RHS.BeginX; + this->Size = RHS.Size; + this->Capacity = RHS.Capacity; + RHS.resetToSmall(); + return *this; + } + + // If we already have sufficient space, assign the common elements, then + // destroy any excess. + size_t RHSSize = RHS.size(); + size_t CurSize = this->size(); + if (CurSize >= RHSSize) { + // Assign common elements. + iterator NewEnd = this->begin(); + if (RHSSize) + NewEnd = std::move(RHS.begin(), RHS.end(), NewEnd); + + // Destroy excess elements and trim the bounds. + this->destroy_range(NewEnd, this->end()); + this->set_size(RHSSize); + + // Clear the RHS. + RHS.clear(); + + return *this; + } + + // If we have to grow to have enough elements, destroy the current elements. + // This allows us to avoid copying them during the grow. + // FIXME: this may not actually make any sense if we can efficiently move + // elements. + if (this->capacity() < RHSSize) { + // Destroy current elements. + this->clear(); + CurSize = 0; + this->grow(RHSSize); + } else if (CurSize) { + // Otherwise, use assignment for the already-constructed elements. + std::move(RHS.begin(), RHS.begin() + CurSize, this->begin()); + } + + // Move-construct the new elements in place. + this->uninitialized_move( + RHS.begin() + CurSize, RHS.end(), this->begin() + CurSize); + + // Set end. + this->set_size(RHSSize); + + RHS.clear(); + return *this; +} + +/// Storage for the SmallVector elements. This is specialized for the N=0 case +/// to avoid allocating unnecessary storage. +template +struct SmallVectorStorage { + alignas(T) char InlineElts[N * sizeof(T)]; +}; + +/// We need the storage to be properly aligned even for small-size of 0 so that +/// the pointer math in \a SmallVectorTemplateCommon::getFirstEl() is +/// well-defined. +template +struct alignas(T) SmallVectorStorage {}; + +/// Forward declaration of SmallVector so that +/// calculateSmallVectorDefaultInlinedElements can reference +/// `sizeof(SmallVector)`. +template +class /* LLVM_GSL_OWNER */ SmallVector; + +/// Helper class for calculating the default number of inline elements for +/// `SmallVector`. +/// +/// This should be migrated to a constexpr function when our minimum +/// compiler support is enough for multi-statement constexpr functions. +template +struct CalculateSmallVectorDefaultInlinedElements { + // Parameter controlling the default number of inlined elements + // for `SmallVector`. + // + // The default number of inlined elements ensures that + // 1. There is at least one inlined element. + // 2. `sizeof(SmallVector) <= kPreferredSmallVectorSizeof` unless + // it contradicts 1. + static constexpr size_t kPreferredSmallVectorSizeof = 64; + + // static_assert that sizeof(T) is not "too big". + // + // Because our policy guarantees at least one inlined element, it is possible + // for an arbitrarily large inlined element to allocate an arbitrarily large + // amount of inline storage. We generally consider it an antipattern for a + // SmallVector to allocate an excessive amount of inline storage, so we want + // to call attention to these cases and make sure that users are making an + // intentional decision if they request a lot of inline storage. + // + // We want this assertion to trigger in pathological cases, but otherwise + // not be too easy to hit. To accomplish that, the cutoff is actually somewhat + // larger than kPreferredSmallVectorSizeof (otherwise, + // `SmallVector>` would be one easy way to trip it, and that + // pattern seems useful in practice). + // + // One wrinkle is that this assertion is in theory non-portable, since + // sizeof(T) is in general platform-dependent. However, we don't expect this + // to be much of an issue, because most LLVM development happens on 64-bit + // hosts, and therefore sizeof(T) is expected to *decrease* when compiled for + // 32-bit hosts, dodging the issue. The reverse situation, where development + // happens on a 32-bit host and then fails due to sizeof(T) *increasing* on a + // 64-bit host, is expected to be very rare. + static_assert( + sizeof(T) <= 256, + "You are trying to use a default number of inlined elements for " + "`SmallVector` but `sizeof(T)` is really big! Please use an " + "explicit number of inlined elements with `SmallVector` to make " + "sure you really want that much inline storage."); + + // Discount the size of the header itself when calculating the maximum inline + // bytes. + static constexpr size_t PreferredInlineBytes = + kPreferredSmallVectorSizeof - sizeof(SmallVector); + static constexpr size_t NumElementsThatFit = PreferredInlineBytes / sizeof(T); + static constexpr size_t value = + NumElementsThatFit == 0 ? 1 : NumElementsThatFit; +}; + +/// This is a 'vector' (really, a variable-sized array), optimized +/// for the case when the array is small. It contains some number of elements +/// in-place, which allows it to avoid heap allocation when the actual number of +/// elements is below that threshold. This allows normal "small" cases to be +/// fast without losing generality for large inputs. +/// +/// \note +/// In the absence of a well-motivated choice for the number of inlined +/// elements \p N, it is recommended to use \c SmallVector (that is, +/// omitting the \p N). This will choose a default number of inlined elements +/// reasonable for allocation on the stack (for example, trying to keep \c +/// sizeof(SmallVector) around 64 bytes). +/// +/// \warning This does not attempt to be exception safe. +/// +/// \see https://llvm.org/docs/ProgrammersManual.html#llvm-adt-smallvector-h +template < + typename T, + unsigned N = CalculateSmallVectorDefaultInlinedElements::value> +class /* LLVM_GSL_OWNER */ SmallVector : public SmallVectorImpl, + SmallVectorStorage { + public: + SmallVector() : SmallVectorImpl(N) {} + + ~SmallVector() { + // Destroy the constructed elements in the vector. + this->destroy_range(this->begin(), this->end()); + } + + explicit SmallVector(size_t Size, const T& Value = T()) + : SmallVectorImpl(N) { + this->assign(Size, Value); + } + + template < + typename ItTy, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + SmallVector(ItTy S, ItTy E) : SmallVectorImpl(N) { + this->append(S, E); + } + + // note: The enable_if restricts Container to types that have a .begin() and + // .end() that return valid input iterators. + template < + typename Container, + std::enable_if_t< + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, + std::input_iterator_tag> && + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, + std::input_iterator_tag>, + int> = 0> + explicit SmallVector(Container&& c) : SmallVectorImpl(N) { + this->append(c.begin(), c.end()); + } + + SmallVector(std::initializer_list IL) : SmallVectorImpl(N) { + this->assign(IL); + } + + SmallVector(const SmallVector& RHS) : SmallVectorImpl(N) { + if (!RHS.empty()) + SmallVectorImpl::operator=(RHS); + } + + SmallVector& operator=(const SmallVector& RHS) { + SmallVectorImpl::operator=(RHS); + return *this; + } + + SmallVector(SmallVector&& RHS) noexcept( + std::is_nothrow_move_assignable_v>) + : SmallVectorImpl(N) { + if (!RHS.empty()) + SmallVectorImpl::operator=(::std::move(RHS)); + } + + // note: The enable_if restricts Container to types that have a .begin() and + // .end() that return valid input iterators. + template < + typename Container, + std::enable_if_t< + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, + std::input_iterator_tag> && + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, + std::input_iterator_tag>, + int> = 0> + SmallVector& operator=(const Container& RHS) { + this->assign(RHS.begin(), RHS.end()); + return *this; + } + + SmallVector(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_assignable_v>) + : SmallVectorImpl(N) { + if (!RHS.empty()) + SmallVectorImpl::operator=(::std::move(RHS)); + } + + SmallVector& operator=(SmallVector&& RHS) noexcept( + std::is_nothrow_move_assignable_v>) { + SmallVectorImpl::operator=(::std::move(RHS)); + return *this; + } + + SmallVector& operator=(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_constructible_v>) { + SmallVectorImpl::operator=(::std::move(RHS)); + return *this; + } + + // note: The enable_if restricts Container to types that have a .begin() and + // .end() that return valid input iterators. + template < + typename Container, + std::enable_if_t< + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, + std::input_iterator_tag> && + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, + std::input_iterator_tag>, + int> = 0> + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + SmallVector& operator=(Container&& C) { + this->assign(C.begin(), C.end()); + return *this; + } + + SmallVector& operator=(std::initializer_list IL) { + this->assign(IL); + return *this; + } +}; + +template +inline size_t capacity_in_bytes(const SmallVector& X) { + return X.capacity_in_bytes(); +} + +template +std::ostream& operator<<(std::ostream& out, const SmallVector& list) { + int i = 0; + out << '['; + for (auto e : list) { + if (i++ > 0) + out << ", "; + out << e; + } + out << ']'; + return out; +} + +template +using ValueTypeFromRangeType = std::remove_const_t< + std::remove_reference_t()))>>; + +/// Given a range of type R, iterate the entire range and return a +/// SmallVector with elements of the vector. This is useful, for example, +/// when you want to iterate a range and then sort the results. +template +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) +SmallVector, Size> to_vector(R&& Range) { + return {std::begin(Range), std::end(Range)}; +} +template +SmallVector< + ValueTypeFromRangeType, + CalculateSmallVectorDefaultInlinedElements< + ValueTypeFromRangeType>::value> +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) +to_vector(R&& Range) { + return {std::begin(Range), std::end(Range)}; +} + +} // end namespace c10 + +namespace std { + +/// Implement std::swap in terms of SmallVector swap. +template +inline void swap( + c10::SmallVectorImpl& LHS, + c10::SmallVectorImpl& RHS) noexcept { + LHS.swap(RHS); +} + +/// Implement std::swap in terms of SmallVector swap. +template +inline void swap( + c10::SmallVector& LHS, + c10::SmallVector& RHS) noexcept { + LHS.swap(RHS); +} + +} // end namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/StringUtil.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/StringUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..7c77905085305f5b2884985df2857a219a760c56 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/StringUtil.h @@ -0,0 +1,267 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_STRINGUTIL_H_ +#define C10_UTIL_STRINGUTIL_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") +#endif + +namespace c10 { + +namespace detail { + +// Obtains the base name from a full path. +C10_API std::string StripBasename(const std::string& full_path); + +C10_API std::string ExcludeFileExtension(const std::string& full_path); + +struct CompileTimeEmptyString { + operator const std::string&() const { + static const std::string empty_string_literal; + return empty_string_literal; + } + operator const char*() const { + return ""; + } +}; + +template +struct CanonicalizeStrTypes { + using type = const T&; +}; + +template +// NOLINTNEXTLINE(*c-arrays*) +struct CanonicalizeStrTypes { + using type = const char*; +}; + +inline std::ostream& _str(std::ostream& ss) { + return ss; +} + +template +struct Streamable : std::false_type {}; + +template +struct Streamable() << T{})> + : std::true_type {}; + +template +inline std::ostream& _str(std::ostream& ss, const T& t) { + if constexpr (std::is_enum_v && !Streamable::value) { + // NOLINTNEXTLINE(modernize-type-traits) + return _str(ss, static_cast::type>(t)); + } else { + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + ss << t; + return ss; + } +} + +template +inline std::ostream& _str(std::ostream& ss, const std::optional& t) { + if (t.has_value()) { + return _str(ss, t.value()); + } + ss << "std::nullopt"; + return ss; +} +// Overloads of _str for wide types; forces narrowing. +C10_API std::ostream& _str(std::ostream& ss, const wchar_t* wCStr); +C10_API std::ostream& _str(std::ostream& ss, const wchar_t& wChar); +C10_API std::ostream& _str(std::ostream& ss, const std::wstring& wString); + +template <> +inline std::ostream& _str( + std::ostream& ss, + const CompileTimeEmptyString& /*unused*/) { + return ss; +} + +template +inline std::ostream& _str(std::ostream& ss, const T& t, const Args&... args) { + return _str(_str(ss, t), args...); +} + +template +struct _str_wrapper final { + static std::string call(const Args&... args) { + std::ostringstream ss; + _str(ss, args...); + return ss.str(); + } +}; + +// Specializations for already-a-string types. +template <> +struct _str_wrapper final { + // return by reference to avoid the binary size of a string copy + static const std::string& call(const std::string& str) { + return str; + } +}; + +template <> +struct _str_wrapper final { + static const char* call(const char* str) { + return str; + } +}; + +// For c10::str() with an empty argument list (which is common in our assert +// macros), we don't want to pay the binary size for constructing and +// destructing a stringstream or even constructing a string. +template <> +struct _str_wrapper<> final { + static CompileTimeEmptyString call() { + return CompileTimeEmptyString(); + } +}; + +} // namespace detail + +// Convert a list of string-like arguments into a single string. +template +inline auto str(const Args&... args) { + return detail::_str_wrapper< + typename detail::CanonicalizeStrTypes::type...>::call(args...); +} + +template +inline std::string Join(const std::string& delimiter, const Container& v) { + std::stringstream s; + int cnt = static_cast(v.size()) - 1; + for (auto i = v.begin(); i != v.end(); ++i, --cnt) { + s << (*i) << (cnt ? delimiter : ""); + } + return std::move(s).str(); +} + +// Replace all occurrences of "from" substring to "to" string. +// Returns number of replacements +size_t C10_API +ReplaceAll(std::string& s, std::string_view from, std::string_view to); + +/// Represents a location in source code (for debugging). +struct C10_API SourceLocation { + const char* function; + const char* file; + uint32_t line; +}; + +std::ostream& operator<<(std::ostream& out, const SourceLocation& loc); + +// unix isprint but insensitive to locale +inline bool isPrint(char s) { + return s > 0x1f && s < 0x7f; +} + +inline void printQuotedString(std::ostream& stmt, const std::string_view str) { + stmt << '"'; + for (auto s : str) { + switch (s) { + case '\\': + stmt << "\\\\"; + break; + case '\'': + stmt << "\\'"; + break; + case '\"': + stmt << "\\\""; + break; + case '\a': + stmt << "\\a"; + break; + case '\b': + stmt << "\\b"; + break; + case '\f': + stmt << "\\f"; + break; + case '\n': + stmt << "\\n"; + break; + case '\r': + stmt << "\\r"; + break; + case '\t': + stmt << "\\t"; + break; + case '\v': + stmt << "\\v"; + break; + default: + if (isPrint(s)) { + stmt << s; + } else { + // C++ io has stateful formatting settings. Messing with + // them is probably worse than doing this manually. + // NOLINTNEXTLINE(*c-arrays*) + char buf[4] = "000"; + // NOLINTNEXTLINE(*narrowing-conversions) + buf[2] += s % 8; + s /= 8; + // NOLINTNEXTLINE(*narrowing-conversions) + buf[1] += s % 8; + s /= 8; + // NOLINTNEXTLINE(*narrowing-conversions) + buf[0] += s; + stmt << "\\" << buf; + } + break; + } + } + stmt << '"'; +} + +template +std::optional tryToNumber(const char* symbol) = delete; +template +std::optional tryToNumber(const std::string& symbol) = delete; + +/* + * Convert a string to a 64 bit integer. Trailing whitespaces are not supported. + * Similarly, integer string with trailing characters like "123abc" will be + * rejected. + */ +template <> +C10_API std::optional tryToNumber(const char* symbol); +template <> +C10_API std::optional tryToNumber(const std::string& symbol); + +/* + * Convert a string to a double. Trailing whitespaces are not supported. + * Similarly, integer string with trailing characters like "123abc" will + * be rejected. + */ +template <> +C10_API std::optional tryToNumber(const char* symbol); +template <> +C10_API std::optional tryToNumber(const std::string& symbol); + +C10_API std::vector split( + std::string_view target, + char delimiter); +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +#endif // C10_UTIL_STRINGUTIL_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Synchronized.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Synchronized.h new file mode 100644 index 0000000000000000000000000000000000000000..c78564263ebfe172abcb5c097a8c222606e8f019 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Synchronized.h @@ -0,0 +1,67 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { + +/** + * A very simple Synchronization class for error-free use of data + * in a multi-threaded context. See folly/docs/Synchronized.md for + * the inspiration of this class. + * + * Full URL: + * https://github.com/facebook/folly/blob/main/folly/docs/Synchronized.md + * + * This class implements a small subset of the generic functionality + * implemented by folly:Synchronized. Specifically, only withLock + * is implemented here since it's the smallest possible API that is + * able to cover a large surface area of functionality offered by + * folly::Synchronized. + */ +template +class Synchronized final { + mutable std::mutex mutex_; + T data_; + + public: + Synchronized() = default; + Synchronized(T const& data) : data_(data) {} + Synchronized(T&& data) : data_(std::move(data)) {} + + // Don't permit copy construction, move, assignment, or + // move assignment, since the underlying std::mutex + // isn't necessarily copyable/moveable. + Synchronized(Synchronized const&) = delete; + Synchronized(Synchronized&&) = delete; + Synchronized operator=(Synchronized const&) = delete; + Synchronized operator=(Synchronized&&) = delete; + ~Synchronized() = default; + + /** + * To use, call withLock with a callback that accepts T either + * by copy or by reference. Use the protected variable in the + * provided callback safely. + */ + template + auto withLock(CB&& cb) { + std::lock_guard guard(this->mutex_); + return std::forward(cb)(this->data_); + } + + /** + * To use, call withLock with a callback that accepts T either + * by copy or by const reference. Use the protected variable in + * the provided callback safely. + */ + template + auto withLock(CB&& cb) const { + std::lock_guard guard(this->mutex_); + return std::forward(cb)(this->data_); + } +}; +} // end namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocal.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocal.h new file mode 100644 index 0000000000000000000000000000000000000000..e5b92117a67fed4731ba92dc0b116b6c5aa80bcd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocal.h @@ -0,0 +1,161 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +/** + * Android versions with libgnustl incorrectly handle thread_local C++ + * qualifier with composite types. NDK up to r17 version is affected. + * + * (A fix landed on Jun 4 2018: + * https://android-review.googlesource.com/c/toolchain/gcc/+/683601) + * + * In such cases, use c10::ThreadLocal wrapper + * which is `pthread_*` based with smart pointer semantics. + * + * In addition, convenient macro C10_DEFINE_TLS_static is available. + * To define static TLS variable of type std::string, do the following + * ``` + * C10_DEFINE_TLS_static(std::string, str_tls_); + * /////// + * { + * *str_tls_ = "abc"; + * assert(str_tls_->length(), 3); + * } + * ``` + * + * (see c10/test/util/ThreadLocal_test.cpp for more examples) + */ +#if !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#if defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604 +#define C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE +#endif // defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604 + +#endif // !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) +#include +#include +#include +#include +namespace c10 { + +/** + * @brief Temporary thread_local C++ qualifier replacement for Android + * based on `pthread_*`. + * To be used with composite types that provide default ctor. + */ +template +class ThreadLocal { + public: + ThreadLocal() { + pthread_key_create( + &key_, [](void* buf) { delete static_cast(buf); }); + } + + ~ThreadLocal() { + if (void* current = pthread_getspecific(key_)) { + delete static_cast(current); + } + + pthread_key_delete(key_); + } + + ThreadLocal(const ThreadLocal&) = delete; + ThreadLocal& operator=(const ThreadLocal&) = delete; + + Type& get() { + if (void* current = pthread_getspecific(key_)) { + return *static_cast(current); + } + + std::unique_ptr ptr = std::make_unique(); + if (0 == pthread_setspecific(key_, ptr.get())) { + return *ptr.release(); + } + + int err = errno; + TORCH_INTERNAL_ASSERT(false, "pthread_setspecific() failed, errno = ", err); + } + + Type& operator*() { + return get(); + } + + Type* operator->() { + return &get(); + } + + private: + pthread_key_t key_; +}; + +} // namespace c10 + +#define C10_DEFINE_TLS_static(Type, Name) static ::c10::ThreadLocal Name + +#define C10_DECLARE_TLS_class_static(Class, Type, Name) \ + static ::c10::ThreadLocal Name + +#define C10_DEFINE_TLS_class_static(Class, Type, Name) \ + ::c10::ThreadLocal Class::Name + +#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +namespace c10 { + +/** + * @brief Default thread_local implementation for non-Android cases. + * To be used with composite types that provide default ctor. + */ +template +class ThreadLocal { + public: + using Accessor = Type* (*)(); + explicit ThreadLocal(Accessor accessor) : accessor_(accessor) {} + + ThreadLocal(const ThreadLocal&) = delete; + ThreadLocal(ThreadLocal&&) noexcept = default; + ThreadLocal& operator=(const ThreadLocal&) = delete; + ThreadLocal& operator=(ThreadLocal&&) noexcept = default; + ~ThreadLocal() = default; + + Type& get() { + return *accessor_(); + } + + Type& operator*() { + return get(); + } + + Type* operator->() { + return &get(); + } + + private: + Accessor accessor_; +}; + +} // namespace c10 + +#define C10_DEFINE_TLS_static(Type, Name) \ + static ::c10::ThreadLocal Name([]() { \ + static thread_local Type var; \ + return &var; \ + }) + +#define C10_DECLARE_TLS_class_static(Class, Type, Name) \ + static ::c10::ThreadLocal Name + +#define C10_DEFINE_TLS_class_static(Class, Type, Name) \ + ::c10::ThreadLocal Class::Name([]() { \ + static thread_local Type var; \ + return &var; \ + }) + +#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocalDebugInfo.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocalDebugInfo.h new file mode 100644 index 0000000000000000000000000000000000000000..03ba6f5b39ba567f65bfa375df66c413a88c171b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocalDebugInfo.h @@ -0,0 +1,90 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include + +namespace c10 { + +enum class C10_API_ENUM DebugInfoKind : uint8_t { + PRODUCER_INFO = 0, + MOBILE_RUNTIME_INFO, + PROFILER_STATE, + INFERENCE_CONTEXT, // for inference usage + PARAM_COMMS_INFO, + + TEST_INFO, // used only in tests + TEST_INFO_2, // used only in tests +}; + +class C10_API DebugInfoBase { + public: + DebugInfoBase() = default; + virtual ~DebugInfoBase() = default; +}; + +// Thread local debug information is propagated across the forward +// (including async fork tasks) and backward passes and is supposed +// to be utilized by the user's code to pass extra information from +// the higher layers (e.g. model id) down to the lower levels +// (e.g. to the operator observers used for debugging, logging, +// profiling, etc) +class C10_API ThreadLocalDebugInfo { + public: + static DebugInfoBase* get(DebugInfoKind kind); + + // Get current ThreadLocalDebugInfo + static std::shared_ptr current(); + + // Internal, use DebugInfoGuard/ThreadLocalStateGuard + static void _forceCurrentDebugInfo( + std::shared_ptr info); + + // Push debug info struct of a given kind + static void _push(DebugInfoKind kind, std::shared_ptr info); + // Pop debug info, throws in case the last pushed + // debug info is not of a given kind + static std::shared_ptr _pop(DebugInfoKind kind); + // Peek debug info, throws in case the last pushed debug info is not of the + // given kind + static std::shared_ptr _peek(DebugInfoKind kind); + + private: + std::shared_ptr info_; + DebugInfoKind kind_; + std::shared_ptr parent_info_; + + friend class DebugInfoGuard; +}; + +// DebugInfoGuard is used to set debug information, +// ThreadLocalDebugInfo is semantically immutable, the values are set +// through the scope-based guard object. +// Nested DebugInfoGuard adds/overrides existing values in the scope, +// restoring the original values after exiting the scope. +// Users can access the values through the ThreadLocalDebugInfo::get() call; +class C10_API DebugInfoGuard { + public: + DebugInfoGuard(DebugInfoKind kind, std::shared_ptr info); + + explicit DebugInfoGuard(std::shared_ptr info); + + ~DebugInfoGuard(); + + DebugInfoGuard(const DebugInfoGuard&) = delete; + DebugInfoGuard(DebugInfoGuard&&) = delete; + DebugInfoGuard& operator=(const DebugInfoGuard&) = delete; + DebugInfoGuard& operator=(DebugInfoGuard&&) = delete; + + private: + bool active_ = false; + std::shared_ptr prev_info_ = nullptr; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Type.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Type.h new file mode 100644 index 0000000000000000000000000000000000000000..9f460d4bde11da8629abece4d994a800e7918fc4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Type.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_TYPE_H_ +#define C10_UTIL_TYPE_H_ + +#include +#include +#ifdef __GXX_RTTI +#include +#endif // __GXX_RTTI + +#include + +namespace c10 { + +/// Utility to demangle a C++ symbol name. +C10_API std::string demangle(const char* name); + +/// Returns the printable name of the type. +template +inline const char* demangle_type() { +#ifdef __GXX_RTTI + static const auto& name = *(new std::string(demangle(typeid(T).name()))); + return name.c_str(); +#else // __GXX_RTTI + return "(RTTI disabled, cannot show name)"; +#endif // __GXX_RTTI +} + +} // namespace c10 + +#endif // C10_UTIL_TYPE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeCast.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeCast.h new file mode 100644 index 0000000000000000000000000000000000000000..1d95fd90929796735962e4fb4fe1855cda857ac5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeCast.h @@ -0,0 +1,215 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +template +struct needs_real { + constexpr static bool value = + (is_complex::value && !is_complex::value); +}; + +template +struct maybe_real { + C10_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_real { + C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + return src.real(); + } +}; + +template +struct maybe_bool { + C10_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_bool { + C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + // Don't use bool operator so as to also compile for ComplexHalf. + return src.real() || src.imag(); + } +}; + +// Note: deliberately ignores undefined behavior, consistent with NumPy. +// PyTorch's type conversions can cause a variety of undefined behavior, +// including float to integral overflow and signed to unsigned integer overflow. +// Some of this undefined behavior is addressed below. +template +struct static_cast_with_inter_type { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply( + src_t src) { + constexpr bool real = needs_real::value; + auto r = maybe_real::apply(src); + return static_cast(r); + } +}; + +// Partial template specialization for casting to bool. +// Need to handle complex types separately, as we don't +// simply want to cast the real part to bool. +template +struct static_cast_with_inter_type { + C10_HOST_DEVICE static inline bool apply(src_t src) { + constexpr bool complex = needs_real::value; + return static_cast(maybe_bool::apply(src)); + } +}; + +// Partial template instantiation for casting to uint8. +// Note: Converting from negative float values to unsigned integer types is +// undefined behavior in C++, and current CPU and GPU compilers exhibit +// divergent behavior. Casting from negative float values to signed +// integer types and then to unsigned integer types is not undefined, +// however, so this cast improves the consistency of type conversions +// to uint8 across compilers. +// Further note: Type conversions across compilers still have other undefined +// and divergent behavior. +template +struct static_cast_with_inter_type { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply( + src_t src) { + constexpr bool real = needs_real::value; + return static_cast( + static_cast(maybe_real::apply(src))); + } +}; + +template <> +struct static_cast_with_inter_type, c10::BFloat16> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::BFloat16 src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type, c10::Float8_e5m2> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e5m2 src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e5m2fnuz> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e5m2fnuz src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e4m3fn> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e4m3fn src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e4m3fnuz> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e4m3fnuz src) { + return static_cast>(c10::complex{src}); + } +}; + +// TODO(#146647): Can we make all these template specialization happen +// based off our apply macros? +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e8m0fnu> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e8m0fnu src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type, c10::Half> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Half src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::complex> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::complex src) { + return static_cast>( + static_cast>(src)); + } +}; + +template +C10_HOST_DEVICE To convert(From f) { + return static_cast_with_inter_type::apply(f); +} + +// Define separately to avoid being inlined and prevent code-size bloat +[[noreturn]] C10_API void report_overflow(const char* name); + +template +To checked_convert(From f, const char* name) { + // Converting to bool can't overflow so we exclude this case from checking. + if (!std::is_same_v && overflows(f)) { + report_overflow(name); + } + return convert(f); +} + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +// Trigger tests for D25440771. TODO: Remove this line any time you want. + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeIndex.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeIndex.h new file mode 100644 index 0000000000000000000000000000000000000000..fe2282d2973c030f2abb788009acf8ce661f3fd8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeIndex.h @@ -0,0 +1,132 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 +#define C10_TYPENAME_CONSTEXPR constexpr +#endif + +namespace c10::util { + +struct type_index final : IdWrapper { + constexpr explicit type_index(uint64_t checksum) : IdWrapper(checksum) {} + + // Allow usage in std::map / std::set + // TODO Disallow this and rather use std::unordered_map/set everywhere + friend constexpr bool operator<(type_index lhs, type_index rhs) noexcept { + return lhs.underlyingId() < rhs.underlyingId(); + } + + friend std::ostream& operator<<(std::ostream& stream, type_index typeId) { + return stream << typeId.underlyingId(); + } +}; + +namespace detail { + +template +inline constexpr c10::c10_string_view fully_qualified_type_name_impl() { +#if defined(_MSC_VER) && !defined(__clang__) + constexpr std::string_view fun_sig = __FUNCSIG__; +#if defined(__NVCC__) + constexpr std::string_view prefix = + "c10::basic_string_view c10::util::detail::fully_qualified_type_name_impl<"; + constexpr std::string_view suffix = ">()"; +#else + constexpr std::string_view prefix = + "class c10::basic_string_view __cdecl c10::util::detail::fully_qualified_type_name_impl<"; + constexpr std::string_view suffix = ">(void)"; +#endif +#elif defined(__clang__) + constexpr std::string_view fun_sig = __PRETTY_FUNCTION__; + constexpr std::string_view prefix = + "c10::c10_string_view c10::util::detail::fully_qualified_type_name_impl() [T = "; + constexpr std::string_view suffix = "]"; +#elif defined(__GNUC__) + constexpr std::string_view fun_sig = __PRETTY_FUNCTION__; + constexpr std::string_view prefix = + "constexpr c10::c10_string_view c10::util::detail::fully_qualified_type_name_impl() [with T = "; + constexpr std::string_view suffix = + "; c10::c10_string_view = c10::basic_string_view]"; +#endif +#if !defined(__CUDA_ARCH__) && !defined(__CUDA_ARCH_LIST__) + static_assert(c10::starts_with( + static_cast(fun_sig), + static_cast(prefix))); + static_assert(c10::ends_with( + static_cast(fun_sig), + static_cast(suffix))); +#endif + return fun_sig.substr( + prefix.size(), fun_sig.size() - prefix.size() - suffix.size()); +} + +#if !defined(__CUDA_ARCH__) && !defined(__CUDA_ARCH_LIST__) +template +inline constexpr uint64_t type_index_impl() { +// Idea: __PRETTY_FUNCTION__ (or __FUNCSIG__ on msvc) contains a qualified name +// of this function, including its template parameter, i.e. including the +// type we want an id for. We use this name and run crc64 on it to get a type +// id. +#if defined(_MSC_VER) && !defined(__clang__) + return crc64(__FUNCSIG__, sizeof(__FUNCSIG__)).checksum(); +#elif defined(__clang__) + return crc64(__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__)).checksum(); +#elif defined(__GNUC__) + return crc64(__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__)).checksum(); +#endif +} +#endif + +} // namespace detail + +template +inline constexpr type_index get_type_index() { +#if !defined(__CUDA_ARCH__) && !defined(__CUDA_ARCH_LIST__) + // To enforce that this is really computed at compile time, we pass the + // type index through std::integral_constant. + return type_index{std::integral_constant< + uint64_t, + detail::type_index_impl>()>::value}; +#else + // There's nothing in theory preventing us from running this on device code + // except for nvcc throwing a compiler error if we enable it. + return (abort(), type_index(0)); +#endif +} + +#if !defined(TORCH_PEDANTIC) +// Use precomputed hashsum for std::string +// Needed to workaround ambiguity in class name resolution +// into __PRETTY_FUNCTION__ when abovementioned class is defined in inlined +// namespace. In multi-ABI C++ library, `std::string` is an alias to +// `std::__cxx11::basic_string` which depending on compiler flags can be +// resolved to `basic_string` either in `std` namespace or in +// `std::__cxx11` one (`__cxx11` is an inline namespace) +template <> +inline constexpr type_index get_type_index() { + // hashsum for std::basic_string + return type_index{4193213214807308375ULL}; +} +#endif + +template +inline constexpr std::string_view get_fully_qualified_type_name() noexcept { + return static_cast( + detail::fully_qualified_type_name_impl()); +} +} // namespace c10::util + +C10_DEFINE_HASH_FOR_IDWRAPPER(c10::util::type_index) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeList.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeList.h new file mode 100644 index 0000000000000000000000000000000000000000..7386baccad1420dd13c2530c31b52b0344fe5b9e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeList.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeSafeSignMath.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeSafeSignMath.h new file mode 100644 index 0000000000000000000000000000000000000000..f511333fc7d9ca2b9e29fd7512e4cd0cb8776b25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeSafeSignMath.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeTraits.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..9d49c82cbd8948cdd7bb2b9fd758f7875e5dfdb7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/TypeTraits.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Unicode.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Unicode.h new file mode 100644 index 0000000000000000000000000000000000000000..68d2c2ce7feac15b4fab16f4124e41633433a213 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Unicode.h @@ -0,0 +1,19 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#if defined(_WIN32) +#include +#include +#include +#endif + +namespace c10 { +#if defined(_WIN32) +C10_API std::wstring u8u16(const std::string& str); +C10_API std::string u16u8(const std::wstring& wstr); +#endif +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/UniqueVoidPtr.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/UniqueVoidPtr.h new file mode 100644 index 0000000000000000000000000000000000000000..dc2ba274cb76d7d7b7c810c8cc318abbe412106a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/UniqueVoidPtr.h @@ -0,0 +1,145 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +#include +#include + +namespace c10 { + +using DeleterFnPtr = void (*)(void*); + +namespace detail { + +// Does not delete anything +C10_API void deleteNothing(void* /*unused*/); + +// A detail::UniqueVoidPtr is an owning smart pointer like unique_ptr, but +// with three major differences: +// +// 1) It is specialized to void +// +// 2) It is specialized for a function pointer deleter +// void(void* ctx); i.e., the deleter doesn't take a +// reference to the data, just to a context pointer +// (erased as void*). In fact, internally, this pointer +// is implemented as having an owning reference to +// context, and a non-owning reference to data; this is why +// you release_context(), not release() (the conventional +// API for release() wouldn't give you enough information +// to properly dispose of the object later.) +// +// 3) The deleter is guaranteed to be called when the unique +// pointer is destructed and the context is non-null; this is different +// from std::unique_ptr where the deleter is not called if the +// data pointer is null. +// +// Some of the methods have slightly different types than std::unique_ptr +// to reflect this. +// +class UniqueVoidPtr { + private: + // Lifetime tied to ctx_ + void* data_; + std::unique_ptr ctx_; + + public: + UniqueVoidPtr() : data_(nullptr), ctx_(nullptr, &deleteNothing) {} + explicit UniqueVoidPtr(void* data) + : data_(data), ctx_(nullptr, &deleteNothing) {} + UniqueVoidPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter) + : data_(data), ctx_(ctx, ctx_deleter ? ctx_deleter : &deleteNothing) {} + void* operator->() const { + return data_; + } + void clear() { + ctx_ = nullptr; + data_ = nullptr; + } + void* get() const { + return data_; + } + + bool /* success */ unsafe_reset_data_and_ctx(void* new_data_and_ctx) { + if (C10_UNLIKELY(ctx_.get_deleter() != &deleteNothing)) { + return false; + } + // seems quicker than calling the no-op deleter when we reset + // NOLINTNEXTLINE(bugprone-unused-return-value) + ctx_.release(); + ctx_.reset(new_data_and_ctx); + data_ = new_data_and_ctx; + return true; + } + + void* get_context() const { + return ctx_.get(); + } + void* release_context() { + return ctx_.release(); + } + std::unique_ptr&& move_context() { + return std::move(ctx_); + } + [[nodiscard]] bool compare_exchange_deleter( + DeleterFnPtr expected_deleter, + DeleterFnPtr new_deleter) { + if (get_deleter() != expected_deleter) + return false; + ctx_ = std::unique_ptr(ctx_.release(), new_deleter); + return true; + } + + template + T* cast_context(DeleterFnPtr expected_deleter) const { + if (get_deleter() != expected_deleter) + return nullptr; + return static_cast(get_context()); + } + operator bool() const { + return data_ || ctx_; + } + DeleterFnPtr get_deleter() const { + return ctx_.get_deleter(); + } +}; + +// Note [How UniqueVoidPtr is implemented] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// UniqueVoidPtr solves a common problem for allocators of tensor data, which +// is that the data pointer (e.g., float*) which you are interested in, is not +// the same as the context pointer (e.g., DLManagedTensor) which you need +// to actually deallocate the data. Under a conventional deleter design, you +// have to store extra context in the deleter itself so that you can actually +// delete the right thing. Implementing this with standard C++ is somewhat +// error-prone: if you use a std::unique_ptr to manage tensors, the deleter will +// not be called if the data pointer is nullptr, which can cause a leak if the +// context pointer is non-null (and the deleter is responsible for freeing both +// the data pointer and the context pointer). +// +// So, in our reimplementation of unique_ptr, which just store the context +// directly in the unique pointer, and attach the deleter to the context +// pointer itself. In simple cases, the context pointer is just the pointer +// itself. + +inline bool operator==(const UniqueVoidPtr& sp, std::nullptr_t) noexcept { + return !sp; +} +inline bool operator==(std::nullptr_t, const UniqueVoidPtr& sp) noexcept { + return !sp; +} +inline bool operator!=(const UniqueVoidPtr& sp, std::nullptr_t) noexcept { + return sp; +} +inline bool operator!=(std::nullptr_t, const UniqueVoidPtr& sp) noexcept { + return sp; +} + +} // namespace detail +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Unroll.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Unroll.h new file mode 100644 index 0000000000000000000000000000000000000000..c1470391c8c4ac75f5055848de538b66beea00b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/Unroll.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +// Utility to guarantee complete unrolling of a loop where the bounds are known +// at compile time. Various pragmas achieve similar effects, but are not as +// portable across compilers. + +// Example: c10::ForcedUnroll<4>{}(f); is equivalent to f(0); f(1); f(2); f(3); + +namespace c10 { + +template +struct ForcedUnroll { + template + C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + ForcedUnroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct ForcedUnroll<1> { + template + C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/WaitCounter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/WaitCounter.h new file mode 100644 index 0000000000000000000000000000000000000000..ccae4f78e54b38345cab1f41b97b70293d0a35a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/WaitCounter.h @@ -0,0 +1,103 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace c10::monitor { +namespace detail { +class WaitCounterImpl; + +class WaitCounterBackendIf { + public: + virtual ~WaitCounterBackendIf() = default; + + virtual intptr_t start( + std::chrono::steady_clock::time_point now) noexcept = 0; + virtual void stop( + std::chrono::steady_clock::time_point now, + intptr_t ctx) noexcept = 0; +}; + +class WaitCounterBackendFactoryIf { + public: + virtual ~WaitCounterBackendFactoryIf() = default; + + // May return nullptr. + // In this case the counter will be ignored by the given backend. + virtual std::unique_ptr create( + std::string_view key) noexcept = 0; +}; + +C10_API void registerWaitCounterBackend( + std::unique_ptr /*factory*/); + +C10_API std::vector> +getRegisteredWaitCounterBackends(); +} // namespace detail + +// A handle to a wait counter. +class C10_API WaitCounterHandle { + public: + explicit WaitCounterHandle(std::string_view key); + + class WaitGuard { + public: + WaitGuard(WaitGuard&& other) noexcept + : handle_{std::exchange(other.handle_, {})}, + ctxs_{std::move(other.ctxs_)} {} + WaitGuard(const WaitGuard&) = delete; + WaitGuard& operator=(const WaitGuard&) = delete; + WaitGuard& operator=(WaitGuard&&) = delete; + + ~WaitGuard() { + stop(); + } + + void stop() { + if (auto handle = std::exchange(handle_, nullptr)) { + handle->stop(ctxs_); + } + } + + private: + WaitGuard(WaitCounterHandle& handle, SmallVector&& ctxs) + : handle_{&handle}, ctxs_{std::move(ctxs)} {} + + friend class WaitCounterHandle; + + WaitCounterHandle* handle_; + SmallVector ctxs_; + }; + + // Starts a waiter + WaitGuard start(); + + private: + // Stops the waiter. Each start() call should be matched by exactly one stop() + // call. + void stop(const SmallVector& ctxs); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + detail::WaitCounterImpl& impl_; +}; +} // namespace c10::monitor + +#define STATIC_WAIT_COUNTER(_key) \ + []() -> ::c10::monitor::WaitCounterHandle& { \ + static ::c10::monitor::WaitCounterHandle handle(#_key); \ + return handle; \ + }() + +#define STATIC_SCOPED_WAIT_COUNTER(_name) \ + auto C10_ANONYMOUS_VARIABLE(SCOPE_GUARD) = STATIC_WAIT_COUNTER(_name).start(); + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/WaitCounterDynamicBackend.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/WaitCounterDynamicBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..141d5431adcc1f51286b864d02cc30c2035e3371 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/WaitCounterDynamicBackend.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::monitor::detail { + +struct WaitCounterDynamicBackend { + void* self{nullptr}; + intptr_t (*start)(void* self, int64_t nowUs){nullptr}; + void (*stop)(void* self, int64_t nowUs, intptr_t ctx){nullptr}; + void (*destroy)(void* self){nullptr}; +}; + +using WaitCounterDynamicBackendInit = + void (*)(WaitCounterDynamicBackend*, const char* key, std::size_t keyLen); + +// This name needs to be updated if anything in the API above is changed. +constexpr std::string_view kWaitCounterDynamicBackendInitFn = + "c10_monitor_wait_counter_dynamic_backend_init_v1"; +} // namespace c10::monitor::detail + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/accumulate.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/accumulate.h new file mode 100644 index 0000000000000000000000000000000000000000..df0899a2ce0697b9ff2d8c395dc81fbb9c2d0f84 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/accumulate.h @@ -0,0 +1,129 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// Sum of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t sum_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), container.end(), static_cast(0)); +} + +/// Sum of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t sum_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate(begin, end, static_cast(0)); +} + +/// Product of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t multiply_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), + container.end(), + static_cast(1), + std::multiplies<>()); +} + +/// Product of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t multiply_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + begin, end, static_cast(1), std::multiplies<>()); +} + +/// Return product of all dimensions starting from k +/// Returns 1 if k>=dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_from_dim(const int k, const C& dims) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0); + + if (k > static_cast(dims.size())) { + return 1; + } else { + auto cbegin = dims.cbegin(); + std::advance(cbegin, k); + return multiply_integers(cbegin, dims.cend()); + } +} + +/// Product of all dims up to k (not including dims[k]) +/// Throws an error if k>dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_to_dim(const int k, const C& dims) { + TORCH_INTERNAL_ASSERT(0 <= k); + TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size()); + + auto cend = dims.cbegin(); + std::advance(cend, k); + return multiply_integers(dims.cbegin(), cend); +} + +/// Product of all dims between k and l (including dims[k] and excluding +/// dims[l]) k and l may be supplied in either order +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_between_dim(int k, int l, const C& dims) { + TORCH_INTERNAL_ASSERT(0 <= k); + TORCH_INTERNAL_ASSERT(0 <= l); + + if (k > l) { + std::swap(k, l); + } + + TORCH_INTERNAL_ASSERT((unsigned)l < dims.size()); + + auto cbegin = dims.cbegin(); + auto cend = dims.cbegin(); + std::advance(cbegin, k); + std::advance(cend, l); + return multiply_integers(cbegin, cend); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/bit_cast.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/bit_cast.h new file mode 100644 index 0000000000000000000000000000000000000000..948d03d509175254b3f54c60a4b501dd62f870b5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/bit_cast.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/bits.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/bits.h new file mode 100644 index 0000000000000000000000000000000000000000..fe5b67c454490e06d88752b708d9543cda0ae6d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/bits.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex.h new file mode 100644 index 0000000000000000000000000000000000000000..ff5ea55c508872c075b181518ff6e1cf537bbc3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex.h @@ -0,0 +1,83 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include + +// std functions +// +// The implementation of these functions also follow the design of C++20 + +namespace std { + +template +constexpr T real(const c10::complex& z) { + return z.real(); +} + +template +constexpr T imag(const c10::complex& z) { + return z.imag(); +} + +template +C10_HOST_DEVICE T abs(const c10::complex& z) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return thrust::abs(static_cast>(z)); +#else + return std::abs(static_cast>(z)); +#endif +} + +#if defined(USE_ROCM) +#define ROCm_Bug(x) +#else +#define ROCm_Bug(x) x +#endif + +template +C10_HOST_DEVICE T arg(const c10::complex& z) { + return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); +} + +#undef ROCm_Bug + +template +constexpr T norm(const c10::complex& z) { + return z.real() * z.real() + z.imag() * z.imag(); +} + +// For std::conj, there are other versions of it: +// constexpr std::complex conj( float z ); +// template< class DoubleOrInteger > +// constexpr std::complex conj( DoubleOrInteger z ); +// constexpr std::complex conj( long double z ); +// These are not implemented +// TODO(@zasdfgbnm): implement them as c10::conj +template +constexpr c10::complex conj(const c10::complex& z) { + return c10::complex(z.real(), -z.imag()); +} + +// Thrust does not have complex --> complex version of thrust::proj, +// so this function is not implemented at c10 right now. +// TODO(@zasdfgbnm): implement it by ourselves + +// There is no c10 version of std::polar, because std::polar always +// returns std::complex. Use c10::polar instead; + +} // namespace std + +#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H +// math functions are included in a separate file +#include // IWYU pragma: keep +// utilities for complex types +#include // IWYU pragma: keep +#undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex_math.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex_math.h new file mode 100644 index 0000000000000000000000000000000000000000..33da59051855d7e726fe83d19b9c39e8ab355317 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex_math.h @@ -0,0 +1,411 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead." +#endif + +namespace c10_complex_math { + +// Exponential functions + +template +C10_HOST_DEVICE inline c10::complex exp(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::exp(static_cast>(x))); +#else + return static_cast>( + std::exp(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log(static_cast>(x))); +#else + return static_cast>( + std::log(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log10(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log10(static_cast>(x))); +#else + return static_cast>( + std::log10(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log2(const c10::complex& x) { + const c10::complex log2 = c10::complex(::log(2.0), 0.0); + return c10_complex_math::log(x) / log2; +} + +// Power functions +// +#if defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) +namespace _detail { +C10_API c10::complex sqrt(const c10::complex& in); +C10_API c10::complex sqrt(const c10::complex& in); +C10_API c10::complex acos(const c10::complex& in); +C10_API c10::complex acos(const c10::complex& in); +} // namespace _detail +#endif + +template +C10_HOST_DEVICE inline c10::complex sqrt(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sqrt(static_cast>(x))); +#elif !( \ + defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))) + return static_cast>( + std::sqrt(static_cast>(x))); +#else + return _detail::sqrt(x); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const T& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const T& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + std::pow(x, static_cast>(y))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const U& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const T& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + std::pow(x, static_cast>(y))); +#endif +} + +// Trigonometric functions + +template +C10_HOST_DEVICE inline c10::complex sin(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sin(static_cast>(x))); +#else + return static_cast>( + std::sin(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex cos(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cos(static_cast>(x))); +#else + return static_cast>( + std::cos(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex tan(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tan(static_cast>(x))); +#else + return static_cast>( + std::tan(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex asin(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asin(static_cast>(x))); +#else + return static_cast>( + std::asin(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex acos(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acos(static_cast>(x))); +#elif !defined(_LIBCPP_VERSION) + return static_cast>( + std::acos(static_cast>(x))); +#else + return _detail::acos(x); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex atan(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atan(static_cast>(x))); +#else + return static_cast>( + std::atan(static_cast>(x))); +#endif +} + +// Hyperbolic functions + +template +C10_HOST_DEVICE inline c10::complex sinh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sinh(static_cast>(x))); +#else + return static_cast>( + std::sinh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex cosh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cosh(static_cast>(x))); +#else + return static_cast>( + std::cosh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex tanh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tanh(static_cast>(x))); +#else + return static_cast>( + std::tanh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex asinh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asinh(static_cast>(x))); +#else + return static_cast>( + std::asinh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex acosh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acosh(static_cast>(x))); +#else + return static_cast>( + std::acosh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex atanh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atanh(static_cast>(x))); +#else + return static_cast>( + std::atanh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log1p(const c10::complex& z) { +#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \ + defined(__HIPCC__) + // For Mac, the new implementation yielded a high relative error. Falling back + // to the old version for now. + // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + // For CUDA we also use this one, as thrust::log(thrust::complex) takes + // *forever* to compile + + // log1p(z) = log(1 + z) + // Let's define 1 + z = r * e ^ (i * a), then we have + // log(r * e ^ (i * a)) = log(r) + i * a + // With z = x + iy, the term r can be written as + // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 + // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 + // So, log(r) is + // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) + // = 0.5 * log1p(x * (x + 2) + y ^ 2) + // we need to use the expression only on certain condition to avoid overflow + // and underflow from `(x * (x + 2) + y ^ 2)` + T x = z.real(); + T y = z.imag(); + T zabs = std::abs(z); + T theta = std::atan2(y, x + T(1)); + if (zabs < 0.5) { + T r = x * (T(2) + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {T(0.5) * std::log1p(r), theta}; + } else { + T z0 = std::hypot(x + 1, y); + return {std::log(z0), theta}; + } +#else + // CPU path + // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + c10::complex u = z + T(1); + if (u == T(1)) { + return z; + } else { + auto log_u = log(u); + if (u - T(1) == z) { + return log_u; + } + return log_u * (z / (u - T(1))); + } +#endif +} + +template +C10_HOST_DEVICE inline c10::complex expm1(const c10::complex& z) { + // expm1(z) = exp(z) - 1 + // Define z = x + i * y + // f = e ^ (x + i * y) - 1 + // = e ^ x * e ^ (i * y) - 1 + // = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y)) + // = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y) + // = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y) + T x = z.real(); + T y = z.imag(); + T a = std::sin(y / 2); + T er = std::expm1(x) * std::cos(y) - T(2) * a * a; + T ei = std::exp(x) * std::sin(y); + return {er, ei}; +} + +} // namespace c10_complex_math + +using c10_complex_math::acos; +using c10_complex_math::acosh; +using c10_complex_math::asin; +using c10_complex_math::asinh; +using c10_complex_math::atan; +using c10_complex_math::atanh; +using c10_complex_math::cos; +using c10_complex_math::cosh; +using c10_complex_math::exp; +using c10_complex_math::expm1; +using c10_complex_math::log; +using c10_complex_math::log10; +using c10_complex_math::log1p; +using c10_complex_math::log2; +using c10_complex_math::pow; +using c10_complex_math::sin; +using c10_complex_math::sinh; +using c10_complex_math::sqrt; +using c10_complex_math::tan; +using c10_complex_math::tanh; + +namespace std { + +using c10_complex_math::acos; +using c10_complex_math::acosh; +using c10_complex_math::asin; +using c10_complex_math::asinh; +using c10_complex_math::atan; +using c10_complex_math::atanh; +using c10_complex_math::cos; +using c10_complex_math::cosh; +using c10_complex_math::exp; +using c10_complex_math::expm1; +using c10_complex_math::log; +using c10_complex_math::log10; +using c10_complex_math::log1p; +using c10_complex_math::log2; +using c10_complex_math::pow; +using c10_complex_math::sin; +using c10_complex_math::sinh; +using c10_complex_math::sqrt; +using c10_complex_math::tan; +using c10_complex_math::tanh; + +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex_utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..44152b72cb35b7df727ece02b089350be04a9f7f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/complex_utils.h @@ -0,0 +1,51 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "c10/util/complex_utils.h is not meant to be individually included. Include c10/util/complex.h instead." +#endif + +#include + +namespace c10 { + +template +struct is_complex : public std::false_type {}; + +template +struct is_complex> : public std::true_type {}; + +template +struct is_complex> : public std::true_type {}; + +// Extract double from std::complex; is identity otherwise +// TODO: Write in more idiomatic C++17 +template +struct scalar_value_type { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; + +} // namespace c10 + +namespace std { + +template +class numeric_limits> : public numeric_limits {}; + +template +bool isnan(const c10::complex& v) { + return std::isnan(v.real()) || std::isnan(v.imag()); +} + +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/copysign.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/copysign.h new file mode 100644 index 0000000000000000000000000000000000000000..6bc7c7956f3986ca3c3f10252bd6eb06a7fd1104 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/copysign.h @@ -0,0 +1,32 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +// Note: Explicit implementation of copysign for Half and BFloat16 +// is needed to workaround g++-7/8 crash on aarch64, but also makes +// copysign faster for the half-precision types +template +inline auto copysign(const T& a, const U& b) { + return std::copysign(a, b); +} + +// Implement copysign for half precision floats using bit ops +// Sign is the most significant bit for both half and bfloat16 types +inline c10::Half copysign(c10::Half a, c10::Half b) { + return c10::Half((a.x & 0x7fff) | (b.x & 0x8000), c10::Half::from_bits()); +} + +inline c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) { + return c10::BFloat16( + (a.x & 0x7fff) | (b.x & 0x8000), c10::BFloat16::from_bits()); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/env.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/env.h new file mode 100644 index 0000000000000000000000000000000000000000..538a6e271f9d56564bcd8ff73071974991513009 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/env.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10::utils { + +// Set an environment variable. +C10_API void set_env( + const char* name, + const char* value, + bool overwrite = true); + +// Checks an environment variable is set. +C10_API bool has_env(const char* name) noexcept; + +// Reads an environment variable and returns +// - std::optional, if set equal to "1" +// - std::optional, if set equal to "0" +// - nullopt, otherwise +// +// NB: +// Issues a warning if the value of the environment variable is not 0 or 1. +C10_API std::optional check_env(const char* name); + +// Reads the value of an environment variable if it is set. +// However, check_env should be used if the value is assumed to be a flag. +C10_API std::optional get_env(const char* name) noexcept; + +} // namespace c10::utils + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/error.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/error.h new file mode 100644 index 0000000000000000000000000000000000000000..4afd8a9ab673ff71cb1d0a58e209262096e86347 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/error.h @@ -0,0 +1,16 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::utils { + +// Get an error string in the thread-safe way. +C10_API std::string str_error(int errnum); + +} // namespace c10::utils + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/flat_hash_map.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/flat_hash_map.h new file mode 100644 index 0000000000000000000000000000000000000000..653401395d4098ea77752e4bafdb64682ac8c242 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/flat_hash_map.h @@ -0,0 +1,2107 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Taken from +// https://github.com/skarupke/flat_hash_map/blob/2c4687431f978f02a3780e24b8b701d22aa32d9c/flat_hash_map.hpp +// with fixes applied: +// - https://github.com/skarupke/flat_hash_map/pull/25 +// - https://github.com/skarupke/flat_hash_map/pull/26 +// - replace size_t with uint64_t to fix it for 32bit +// - add "GCC diagnostic" pragma to ignore -Wshadow +// - make sherwood_v3_table::convertible_to_iterator public because GCC5 seems +// to have issues with it otherwise +// - fix compiler warnings in operator templated_iterator +// - make use of 'if constexpr' and eliminate AssignIfTrue template + +// Copyright Malte Skarupke 2017. +// Distributed under the Boost Software License, Version 1.0. +// (See http://www.boost.org/LICENSE_1_0.txt) + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 4624) // destructor was implicitly defined as deleted +#endif + +#ifdef _MSC_VER +#define SKA_NOINLINE(...) __declspec(noinline) __VA_ARGS__ +#else +#define SKA_NOINLINE(...) __VA_ARGS__ __attribute__((noinline)) +#endif + +namespace ska { +struct prime_number_hash_policy; +struct power_of_two_hash_policy; +struct fibonacci_hash_policy; + +namespace detailv3 { +template +struct functor_storage : Functor { + functor_storage() = default; + functor_storage(const Functor& functor) : Functor(functor) {} + template + Result operator()(Args&&... args) { + return static_cast(*this)(std::forward(args)...); + } + template + Result operator()(Args&&... args) const { + return static_cast(*this)(std::forward(args)...); + } +}; +template +struct functor_storage { + typedef Result (*function_ptr)(Args...); + function_ptr function; + functor_storage(function_ptr function) : function(function) {} + Result operator()(Args... args) const { + return function(std::forward(args)...); + } + operator function_ptr&() { + return function; + } + operator const function_ptr&() { + return function; + } +}; +template +struct KeyOrValueHasher : functor_storage { + typedef functor_storage hasher_storage; + KeyOrValueHasher() = default; + KeyOrValueHasher(const hasher& hash) : hasher_storage(hash) {} + uint64_t operator()(const key_type& key) { + return static_cast(*this)(key); + } + uint64_t operator()(const key_type& key) const { + return static_cast(*this)(key); + } + uint64_t operator()(const value_type& value) { + return static_cast(*this)(value.first); + } + uint64_t operator()(const value_type& value) const { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair& value) { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair& value) const { + return static_cast(*this)(value.first); + } +}; +template +struct KeyOrValueEquality : functor_storage { + typedef functor_storage equality_storage; + KeyOrValueEquality() = default; + KeyOrValueEquality(const key_equal& equality) : equality_storage(equality) {} + bool operator()(const key_type& lhs, const key_type& rhs) { + return static_cast(*this)(lhs, rhs); + } + bool operator()(const key_type& lhs, const value_type& rhs) { + return static_cast(*this)(lhs, rhs.first); + } + bool operator()(const value_type& lhs, const key_type& rhs) { + return static_cast(*this)(lhs.first, rhs); + } + bool operator()(const value_type& lhs, const value_type& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const key_type& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs, rhs.first); + } + template + bool operator()(const std::pair& lhs, const key_type& rhs) { + return static_cast(*this)(lhs.first, rhs); + } + template + bool operator()(const value_type& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair& lhs, const value_type& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } +}; +static constexpr int8_t min_lookups = 4; +template +struct sherwood_v3_entry { + sherwood_v3_entry() = default; + sherwood_v3_entry(int8_t distance_from_desired) + : distance_from_desired(distance_from_desired) {} + ~sherwood_v3_entry() = default; + + bool has_value() const { + return distance_from_desired >= 0; + } + bool is_empty() const { + return distance_from_desired < 0; + } + bool is_at_desired_position() const { + return distance_from_desired <= 0; + } + template + void emplace(int8_t distance, Args&&... args) { + new (std::addressof(value)) T(std::forward(args)...); + distance_from_desired = distance; + } + + void destroy_value() { + value.~T(); + distance_from_desired = -1; + } + + int8_t distance_from_desired = -1; + static constexpr int8_t special_end_value = 0; + union { + T value; + }; +}; + +inline int8_t log2(uint64_t value) { + // NOLINTNEXTLINE(*c-arrays*) + static constexpr int8_t table[64] = { + 63, 0, 58, 1, 59, 47, 53, 2, 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, 44, 24, 15, 8, 23, 7, 6, 5}; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + return table[((value - (value >> 1)) * 0x07EDD5E59A4E28C2) >> 58]; +} + +inline uint64_t next_power_of_two(uint64_t i) { + --i; + i |= i >> 1; + i |= i >> 2; + i |= i >> 4; + i |= i >> 8; + i |= i >> 16; + i |= i >> 32; + ++i; + return i; +} + +// Implementation taken from http://en.cppreference.com/w/cpp/types/void_t +// (it takes CWG1558 into account and also works for older compilers) +template +struct make_void { + typedef void type; +}; +template +using void_t = typename make_void::type; + +template +struct HashPolicySelector { + typedef fibonacci_hash_policy type; +}; +template +struct HashPolicySelector> { + typedef typename T::hash_policy type; +}; + +template < + typename T, + typename FindKey, + typename ArgumentHash, + typename DetailHasher, + typename ArgumentEqual, + typename Equal, + typename ArgumentAlloc, + typename EntryAlloc> +class sherwood_v3_table : private EntryAlloc, + private DetailHasher, + private Equal { + using Entry = detailv3::sherwood_v3_entry; + using AllocatorTraits = std::allocator_traits; + using EntryPointer = typename AllocatorTraits::pointer; + + public: + struct convertible_to_iterator; + + using value_type = T; + using size_type = uint64_t; + using difference_type = std::ptrdiff_t; + using hasher = ArgumentHash; + using key_equal = ArgumentEqual; + using allocator_type = EntryAlloc; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + + sherwood_v3_table() = default; + explicit sherwood_v3_table( + size_type bucket_count, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : EntryAlloc(alloc), DetailHasher(hash), Equal(equal) { + rehash(bucket_count); + } + sherwood_v3_table(size_type bucket_count, const ArgumentAlloc& alloc) + : sherwood_v3_table( + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table(bucket_count, hash, ArgumentEqual(), alloc) {} + explicit sherwood_v3_table(const ArgumentAlloc& alloc) : EntryAlloc(alloc) {} + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count = 0, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) { + insert(first, last); + } + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + first, + last, + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + first, + last, + bucket_count, + hash, + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count = 0, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) { + if (bucket_count == 0) + rehash(il.size()); + insert(il.begin(), il.end()); + } + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + il, + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table(il, bucket_count, hash, ArgumentEqual(), alloc) {} + sherwood_v3_table(const sherwood_v3_table& other) + : sherwood_v3_table( + other, + AllocatorTraits::select_on_container_copy_construction( + other.get_allocator())) {} + sherwood_v3_table(const sherwood_v3_table& other, const ArgumentAlloc& alloc) + : EntryAlloc(alloc), + DetailHasher(other), + Equal(other), + _max_load_factor(other._max_load_factor) { + rehash_for_other_container(other); + try { + insert(other.begin(), other.end()); + } catch (...) { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + throw; + } + } + sherwood_v3_table(sherwood_v3_table&& other) noexcept + : EntryAlloc(std::move(other)), + DetailHasher(std::move(other)), + Equal(std::move(other)) { + swap_pointers(other); + } + sherwood_v3_table( + sherwood_v3_table&& other, + const ArgumentAlloc& alloc) noexcept + : EntryAlloc(alloc), + DetailHasher(std::move(other)), + Equal(std::move(other)) { + swap_pointers(other); + } + sherwood_v3_table& operator=(const sherwood_v3_table& other) { + if (this == std::addressof(other)) + return *this; + + clear(); + if constexpr (AllocatorTraits::propagate_on_container_copy_assignment:: + value) { + if (static_cast(*this) != + static_cast(other)) { + reset_to_empty_state(); + } + static_cast(*this) = other; + } + _max_load_factor = other._max_load_factor; + static_cast(*this) = other; + static_cast(*this) = other; + rehash_for_other_container(other); + insert(other.begin(), other.end()); + return *this; + } + sherwood_v3_table& operator=(sherwood_v3_table&& other) noexcept { + if (this == std::addressof(other)) + return *this; + else if constexpr (AllocatorTraits::propagate_on_container_move_assignment:: + value) { + clear(); + reset_to_empty_state(); + static_cast(*this) = std::move(other); + swap_pointers(other); + } else if ( + static_cast(*this) == static_cast(other)) { + swap_pointers(other); + } else { + clear(); + _max_load_factor = other._max_load_factor; + rehash_for_other_container(other); + for (T& elem : other) + emplace(std::move(elem)); + other.clear(); + } + static_cast(*this) = std::move(other); + static_cast(*this) = std::move(other); + return *this; + } + ~sherwood_v3_table() { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + } + + const allocator_type& get_allocator() const { + return static_cast(*this); + } + const ArgumentEqual& key_eq() const { + return static_cast(*this); + } + const ArgumentHash& hash_function() const { + return static_cast(*this); + } + + template + struct templated_iterator { + templated_iterator() = default; + templated_iterator(EntryPointer current) : current(current) {} + EntryPointer current = EntryPointer(); + + using iterator_category = std::forward_iterator_tag; + using value_type = ValueType; + using difference_type = ptrdiff_t; + using pointer = ValueType*; + using reference = ValueType&; + + friend bool operator==( + const templated_iterator& lhs, + const templated_iterator& rhs) { + return lhs.current == rhs.current; + } + friend bool operator!=( + const templated_iterator& lhs, + const templated_iterator& rhs) { + return !(lhs == rhs); + } + + templated_iterator& operator++() { + do { + ++current; + } while (current->is_empty()); + return *this; + } + templated_iterator operator++(int) { + templated_iterator copy(*this); + ++*this; + return copy; + } + + ValueType& operator*() const { + return current->value; + } + ValueType* operator->() const { + return std::addressof(current->value); + } + + // the template automatically disables the operator when value_type is + // already const, because that would cause a lot of compiler warnings + // otherwise. + template < + class target_type = const value_type, + class = std::enable_if_t< + std::is_same_v && + !std::is_same_v>> + operator templated_iterator() const { + return {current}; + } + }; + using iterator = templated_iterator; + using const_iterator = templated_iterator; + + iterator begin() { + for (EntryPointer it = entries;; ++it) { + if (it->has_value()) + return {it}; + } + } + const_iterator begin() const { + for (EntryPointer it = entries;; ++it) { + if (it->has_value()) + return {it}; + } + } + const_iterator cbegin() const { + return begin(); + } + iterator end() { + return { + entries + static_cast(num_slots_minus_one + max_lookups)}; + } + const_iterator end() const { + return { + entries + static_cast(num_slots_minus_one + max_lookups)}; + } + const_iterator cend() const { + return end(); + } + + iterator find(const FindKey& key) { + uint64_t index = + hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer it = entries + ptrdiff_t(index); + for (int8_t distance = 0; it->distance_from_desired >= distance; + ++distance, ++it) { + if (compares_equal(key, it->value)) + return {it}; + } + return end(); + } + const_iterator find(const FindKey& key) const { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + return const_cast(this)->find(key); + } + uint64_t count(const FindKey& key) const { + return find(key) == end() ? 0 : 1; + } + std::pair equal_range(const FindKey& key) { + iterator found = find(key); + if (found == end()) + return {found, found}; + else + return {found, std::next(found)}; + } + std::pair equal_range( + const FindKey& key) const { + const_iterator found = find(key); + if (found == end()) + return {found, found}; + else + return {found, std::next(found)}; + } + + template + std::pair emplace(Key&& key, Args&&... args) { + uint64_t index = + hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer current_entry = entries + ptrdiff_t(index); + int8_t distance_from_desired = 0; + for (; current_entry->distance_from_desired >= distance_from_desired; + ++current_entry, ++distance_from_desired) { + if (compares_equal(key, current_entry->value)) + return {{current_entry}, false}; + } + return emplace_new_key( + distance_from_desired, + current_entry, + std::forward(key), + std::forward(args)...); + } + + std::pair insert(const value_type& value) { + return emplace(value); + } + std::pair insert(value_type&& value) { + return emplace(std::move(value)); + } + template + iterator emplace_hint(const_iterator /*unused*/, Args&&... args) { + return emplace(std::forward(args)...).first; + } + iterator insert(const_iterator /*unused*/, const value_type& value) { + return emplace(value).first; + } + iterator insert(const_iterator /*unused*/, value_type&& value) { + return emplace(std::move(value)).first; + } + + template + void insert(It begin, It end) { + for (; begin != end; ++begin) { + emplace(*begin); + } + } + void insert(std::initializer_list il) { + insert(il.begin(), il.end()); + } + + void rehash(uint64_t num_buckets) { + num_buckets = std::max( + num_buckets, + static_cast( + std::ceil(num_elements / static_cast(_max_load_factor)))); + if (num_buckets == 0) { + reset_to_empty_state(); + return; + } + auto new_prime_index = hash_policy.next_size_over(num_buckets); + if (num_buckets == bucket_count()) + return; + int8_t new_max_lookups = compute_max_lookups(num_buckets); + EntryPointer new_buckets( + AllocatorTraits::allocate(*this, num_buckets + new_max_lookups)); + EntryPointer special_end_item = + new_buckets + static_cast(num_buckets + new_max_lookups - 1); + for (EntryPointer it = new_buckets; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + std::swap(entries, new_buckets); + std::swap(num_slots_minus_one, num_buckets); + --num_slots_minus_one; + hash_policy.commit(new_prime_index); + int8_t old_max_lookups = max_lookups; + max_lookups = new_max_lookups; + num_elements = 0; + for (EntryPointer + it = new_buckets, + end = it + static_cast(num_buckets + old_max_lookups); + it != end; + ++it) { + if (it->has_value()) { + emplace(std::move(it->value)); + it->destroy_value(); + } + } + deallocate_data(new_buckets, num_buckets, old_max_lookups); + } + + void reserve(uint64_t num_elements_) { + uint64_t required_buckets = num_buckets_for_reserve(num_elements_); + if (required_buckets > bucket_count()) + rehash(required_buckets); + } + + // the return value is a type that can be converted to an iterator + // the reason for doing this is that it's not free to find the + // iterator pointing at the next element. if you care about the + // next iterator, turn the return value into an iterator + convertible_to_iterator erase(const_iterator to_erase) { + EntryPointer current = to_erase.current; + current->destroy_value(); + --num_elements; + for (EntryPointer next = current + ptrdiff_t(1); + !next->is_at_desired_position(); + ++current, ++next) { + current->emplace(next->distance_from_desired - 1, std::move(next->value)); + next->destroy_value(); + } + return {to_erase.current}; + } + + iterator erase(const_iterator begin_it, const_iterator end_it) { + if (begin_it == end_it) + return {begin_it.current}; + for (EntryPointer it = begin_it.current, end = end_it.current; it != end; + ++it) { + if (it->has_value()) { + it->destroy_value(); + --num_elements; + } + } + if (end_it == this->end()) + return this->end(); + ptrdiff_t num_to_move = std::min( + static_cast(end_it.current->distance_from_desired), + end_it.current - begin_it.current); + EntryPointer to_return = end_it.current - num_to_move; + for (EntryPointer it = end_it.current; !it->is_at_desired_position();) { + EntryPointer target = it - num_to_move; + target->emplace( + it->distance_from_desired - num_to_move, std::move(it->value)); + it->destroy_value(); + ++it; + num_to_move = std::min( + static_cast(it->distance_from_desired), num_to_move); + } + return {to_return}; + } + + uint64_t erase(const FindKey& key) { + auto found = find(key); + if (found == end()) + return 0; + else { + erase(found); + return 1; + } + } + + void clear() { + for (EntryPointer it = entries, + end = it + + static_cast(num_slots_minus_one + max_lookups); + it != end; + ++it) { + if (it->has_value()) + it->destroy_value(); + } + num_elements = 0; + } + + void shrink_to_fit() { + rehash_for_other_container(*this); + } + + void swap(sherwood_v3_table& other) noexcept { + using std::swap; + swap_pointers(other); + swap(static_cast(*this), static_cast(other)); + swap( + static_cast(*this), static_cast(other)); + if (AllocatorTraits::propagate_on_container_swap::value) + swap(static_cast(*this), static_cast(other)); + } + + uint64_t size() const { + return num_elements; + } + uint64_t max_size() const { + return (AllocatorTraits::max_size(*this)) / sizeof(Entry); + } + uint64_t bucket_count() const { + return num_slots_minus_one ? num_slots_minus_one + 1 : 0; + } + size_type max_bucket_count() const { + return (AllocatorTraits::max_size(*this) - min_lookups) / sizeof(Entry); + } + uint64_t bucket(const FindKey& key) const { + return hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + } + float load_factor() const { + uint64_t buckets = bucket_count(); + if (buckets) + return static_cast(num_elements) / bucket_count(); + else + return 0; + } + void max_load_factor(float value) { + _max_load_factor = value; + } + float max_load_factor() const { + return _max_load_factor; + } + + bool empty() const { + return num_elements == 0; + } + + private: + EntryPointer entries = empty_default_table(); + uint64_t num_slots_minus_one = 0; + typename HashPolicySelector::type hash_policy; + int8_t max_lookups = detailv3::min_lookups - 1; + float _max_load_factor = 0.5f; + uint64_t num_elements = 0; + + EntryPointer empty_default_table() { + EntryPointer result = + AllocatorTraits::allocate(*this, detailv3::min_lookups); + EntryPointer special_end_item = + result + static_cast(detailv3::min_lookups - 1); + for (EntryPointer it = result; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + return result; + } + + static int8_t compute_max_lookups(uint64_t num_buckets) { + int8_t desired = detailv3::log2(num_buckets); + return std::max(detailv3::min_lookups, desired); + } + + uint64_t num_buckets_for_reserve(uint64_t num_elements_) const { + return static_cast(std::ceil( + static_cast(num_elements_) / + std::min(0.5, static_cast(_max_load_factor)))); + } + void rehash_for_other_container(const sherwood_v3_table& other) { + rehash( + std::min(num_buckets_for_reserve(other.size()), other.bucket_count())); + } + + void swap_pointers(sherwood_v3_table& other) { + using std::swap; + swap(hash_policy, other.hash_policy); + swap(entries, other.entries); + swap(num_slots_minus_one, other.num_slots_minus_one); + swap(num_elements, other.num_elements); + swap(max_lookups, other.max_lookups); + swap(_max_load_factor, other._max_load_factor); + } + + template + SKA_NOINLINE(std::pair) + emplace_new_key( + int8_t distance_from_desired, + EntryPointer current_entry, + Key&& key, + Args&&... args) { + using std::swap; + if (num_slots_minus_one == 0 || distance_from_desired == max_lookups || + num_elements + 1 > + (num_slots_minus_one + 1) * static_cast(_max_load_factor)) { + grow(); + return emplace(std::forward(key), std::forward(args)...); + } else if (current_entry->is_empty()) { + current_entry->emplace( + distance_from_desired, + std::forward(key), + std::forward(args)...); + ++num_elements; + return {{current_entry}, true}; + } + value_type to_insert(std::forward(key), std::forward(args)...); + swap(distance_from_desired, current_entry->distance_from_desired); + swap(to_insert, current_entry->value); + iterator result = {current_entry}; + for (++distance_from_desired, ++current_entry;; ++current_entry) { + if (current_entry->is_empty()) { + current_entry->emplace(distance_from_desired, std::move(to_insert)); + ++num_elements; + return {result, true}; + } else if (current_entry->distance_from_desired < distance_from_desired) { + swap(distance_from_desired, current_entry->distance_from_desired); + swap(to_insert, current_entry->value); + ++distance_from_desired; + } else { + ++distance_from_desired; + if (distance_from_desired == max_lookups) { + swap(to_insert, result.current->value); + grow(); + return emplace(std::move(to_insert)); + } + } + } + } + + void grow() { + rehash(std::max(uint64_t(4), 2 * bucket_count())); + } + + void deallocate_data( + EntryPointer begin, + uint64_t num_slots_minus_one_, + int8_t max_lookups_) { + AllocatorTraits::deallocate( + *this, begin, num_slots_minus_one_ + max_lookups_ + 1); + } + + void reset_to_empty_state() { + deallocate_data(entries, num_slots_minus_one, max_lookups); + entries = empty_default_table(); + num_slots_minus_one = 0; + hash_policy.reset(); + max_lookups = detailv3::min_lookups - 1; + } + + template + uint64_t hash_object(const U& key) { + return static_cast(*this)(key); + } + template + uint64_t hash_object(const U& key) const { + return static_cast(*this)(key); + } + template + bool compares_equal(const L& lhs, const R& rhs) { + return static_cast(*this)(lhs, rhs); + } + + public: + struct convertible_to_iterator { + EntryPointer it; + + operator iterator() { + if (it->has_value()) + return {it}; + else + return ++iterator{it}; + } + operator const_iterator() { + if (it->has_value()) + return {it}; + else + return ++const_iterator{it}; + } + }; +}; +} // namespace detailv3 + +struct prime_number_hash_policy { + static uint64_t mod0(uint64_t /*unused*/) { + return 0llu; + } + static uint64_t mod2(uint64_t hash) { + return hash % 2llu; + } + static uint64_t mod3(uint64_t hash) { + return hash % 3llu; + } + static uint64_t mod5(uint64_t hash) { + return hash % 5llu; + } + static uint64_t mod7(uint64_t hash) { + return hash % 7llu; + } + static uint64_t mod11(uint64_t hash) { + return hash % 11llu; + } + static uint64_t mod13(uint64_t hash) { + return hash % 13llu; + } + static uint64_t mod17(uint64_t hash) { + return hash % 17llu; + } + static uint64_t mod23(uint64_t hash) { + return hash % 23llu; + } + static uint64_t mod29(uint64_t hash) { + return hash % 29llu; + } + static uint64_t mod37(uint64_t hash) { + return hash % 37llu; + } + static uint64_t mod47(uint64_t hash) { + return hash % 47llu; + } + static uint64_t mod59(uint64_t hash) { + return hash % 59llu; + } + static uint64_t mod73(uint64_t hash) { + return hash % 73llu; + } + static uint64_t mod97(uint64_t hash) { + return hash % 97llu; + } + static uint64_t mod127(uint64_t hash) { + return hash % 127llu; + } + static uint64_t mod151(uint64_t hash) { + return hash % 151llu; + } + static uint64_t mod197(uint64_t hash) { + return hash % 197llu; + } + static uint64_t mod251(uint64_t hash) { + return hash % 251llu; + } + static uint64_t mod313(uint64_t hash) { + return hash % 313llu; + } + static uint64_t mod397(uint64_t hash) { + return hash % 397llu; + } + static uint64_t mod499(uint64_t hash) { + return hash % 499llu; + } + static uint64_t mod631(uint64_t hash) { + return hash % 631llu; + } + static uint64_t mod797(uint64_t hash) { + return hash % 797llu; + } + static uint64_t mod1009(uint64_t hash) { + return hash % 1009llu; + } + static uint64_t mod1259(uint64_t hash) { + return hash % 1259llu; + } + static uint64_t mod1597(uint64_t hash) { + return hash % 1597llu; + } + static uint64_t mod2011(uint64_t hash) { + return hash % 2011llu; + } + static uint64_t mod2539(uint64_t hash) { + return hash % 2539llu; + } + static uint64_t mod3203(uint64_t hash) { + return hash % 3203llu; + } + static uint64_t mod4027(uint64_t hash) { + return hash % 4027llu; + } + static uint64_t mod5087(uint64_t hash) { + return hash % 5087llu; + } + static uint64_t mod6421(uint64_t hash) { + return hash % 6421llu; + } + static uint64_t mod8089(uint64_t hash) { + return hash % 8089llu; + } + static uint64_t mod10193(uint64_t hash) { + return hash % 10193llu; + } + static uint64_t mod12853(uint64_t hash) { + return hash % 12853llu; + } + static uint64_t mod16193(uint64_t hash) { + return hash % 16193llu; + } + static uint64_t mod20399(uint64_t hash) { + return hash % 20399llu; + } + static uint64_t mod25717(uint64_t hash) { + return hash % 25717llu; + } + static uint64_t mod32401(uint64_t hash) { + return hash % 32401llu; + } + static uint64_t mod40823(uint64_t hash) { + return hash % 40823llu; + } + static uint64_t mod51437(uint64_t hash) { + return hash % 51437llu; + } + static uint64_t mod64811(uint64_t hash) { + return hash % 64811llu; + } + static uint64_t mod81649(uint64_t hash) { + return hash % 81649llu; + } + static uint64_t mod102877(uint64_t hash) { + return hash % 102877llu; + } + static uint64_t mod129607(uint64_t hash) { + return hash % 129607llu; + } + static uint64_t mod163307(uint64_t hash) { + return hash % 163307llu; + } + static uint64_t mod205759(uint64_t hash) { + return hash % 205759llu; + } + static uint64_t mod259229(uint64_t hash) { + return hash % 259229llu; + } + static uint64_t mod326617(uint64_t hash) { + return hash % 326617llu; + } + static uint64_t mod411527(uint64_t hash) { + return hash % 411527llu; + } + static uint64_t mod518509(uint64_t hash) { + return hash % 518509llu; + } + static uint64_t mod653267(uint64_t hash) { + return hash % 653267llu; + } + static uint64_t mod823117(uint64_t hash) { + return hash % 823117llu; + } + static uint64_t mod1037059(uint64_t hash) { + return hash % 1037059llu; + } + static uint64_t mod1306601(uint64_t hash) { + return hash % 1306601llu; + } + static uint64_t mod1646237(uint64_t hash) { + return hash % 1646237llu; + } + static uint64_t mod2074129(uint64_t hash) { + return hash % 2074129llu; + } + static uint64_t mod2613229(uint64_t hash) { + return hash % 2613229llu; + } + static uint64_t mod3292489(uint64_t hash) { + return hash % 3292489llu; + } + static uint64_t mod4148279(uint64_t hash) { + return hash % 4148279llu; + } + static uint64_t mod5226491(uint64_t hash) { + return hash % 5226491llu; + } + static uint64_t mod6584983(uint64_t hash) { + return hash % 6584983llu; + } + static uint64_t mod8296553(uint64_t hash) { + return hash % 8296553llu; + } + static uint64_t mod10453007(uint64_t hash) { + return hash % 10453007llu; + } + static uint64_t mod13169977(uint64_t hash) { + return hash % 13169977llu; + } + static uint64_t mod16593127(uint64_t hash) { + return hash % 16593127llu; + } + static uint64_t mod20906033(uint64_t hash) { + return hash % 20906033llu; + } + static uint64_t mod26339969(uint64_t hash) { + return hash % 26339969llu; + } + static uint64_t mod33186281(uint64_t hash) { + return hash % 33186281llu; + } + static uint64_t mod41812097(uint64_t hash) { + return hash % 41812097llu; + } + static uint64_t mod52679969(uint64_t hash) { + return hash % 52679969llu; + } + static uint64_t mod66372617(uint64_t hash) { + return hash % 66372617llu; + } + static uint64_t mod83624237(uint64_t hash) { + return hash % 83624237llu; + } + static uint64_t mod105359939(uint64_t hash) { + return hash % 105359939llu; + } + static uint64_t mod132745199(uint64_t hash) { + return hash % 132745199llu; + } + static uint64_t mod167248483(uint64_t hash) { + return hash % 167248483llu; + } + static uint64_t mod210719881(uint64_t hash) { + return hash % 210719881llu; + } + static uint64_t mod265490441(uint64_t hash) { + return hash % 265490441llu; + } + static uint64_t mod334496971(uint64_t hash) { + return hash % 334496971llu; + } + static uint64_t mod421439783(uint64_t hash) { + return hash % 421439783llu; + } + static uint64_t mod530980861(uint64_t hash) { + return hash % 530980861llu; + } + static uint64_t mod668993977(uint64_t hash) { + return hash % 668993977llu; + } + static uint64_t mod842879579(uint64_t hash) { + return hash % 842879579llu; + } + static uint64_t mod1061961721(uint64_t hash) { + return hash % 1061961721llu; + } + static uint64_t mod1337987929(uint64_t hash) { + return hash % 1337987929llu; + } + static uint64_t mod1685759167(uint64_t hash) { + return hash % 1685759167llu; + } + static uint64_t mod2123923447(uint64_t hash) { + return hash % 2123923447llu; + } + static uint64_t mod2675975881(uint64_t hash) { + return hash % 2675975881llu; + } + static uint64_t mod3371518343(uint64_t hash) { + return hash % 3371518343llu; + } + static uint64_t mod4247846927(uint64_t hash) { + return hash % 4247846927llu; + } + static uint64_t mod5351951779(uint64_t hash) { + return hash % 5351951779llu; + } + static uint64_t mod6743036717(uint64_t hash) { + return hash % 6743036717llu; + } + static uint64_t mod8495693897(uint64_t hash) { + return hash % 8495693897llu; + } + static uint64_t mod10703903591(uint64_t hash) { + return hash % 10703903591llu; + } + static uint64_t mod13486073473(uint64_t hash) { + return hash % 13486073473llu; + } + static uint64_t mod16991387857(uint64_t hash) { + return hash % 16991387857llu; + } + static uint64_t mod21407807219(uint64_t hash) { + return hash % 21407807219llu; + } + static uint64_t mod26972146961(uint64_t hash) { + return hash % 26972146961llu; + } + static uint64_t mod33982775741(uint64_t hash) { + return hash % 33982775741llu; + } + static uint64_t mod42815614441(uint64_t hash) { + return hash % 42815614441llu; + } + static uint64_t mod53944293929(uint64_t hash) { + return hash % 53944293929llu; + } + static uint64_t mod67965551447(uint64_t hash) { + return hash % 67965551447llu; + } + static uint64_t mod85631228929(uint64_t hash) { + return hash % 85631228929llu; + } + static uint64_t mod107888587883(uint64_t hash) { + return hash % 107888587883llu; + } + static uint64_t mod135931102921(uint64_t hash) { + return hash % 135931102921llu; + } + static uint64_t mod171262457903(uint64_t hash) { + return hash % 171262457903llu; + } + static uint64_t mod215777175787(uint64_t hash) { + return hash % 215777175787llu; + } + static uint64_t mod271862205833(uint64_t hash) { + return hash % 271862205833llu; + } + static uint64_t mod342524915839(uint64_t hash) { + return hash % 342524915839llu; + } + static uint64_t mod431554351609(uint64_t hash) { + return hash % 431554351609llu; + } + static uint64_t mod543724411781(uint64_t hash) { + return hash % 543724411781llu; + } + static uint64_t mod685049831731(uint64_t hash) { + return hash % 685049831731llu; + } + static uint64_t mod863108703229(uint64_t hash) { + return hash % 863108703229llu; + } + static uint64_t mod1087448823553(uint64_t hash) { + return hash % 1087448823553llu; + } + static uint64_t mod1370099663459(uint64_t hash) { + return hash % 1370099663459llu; + } + static uint64_t mod1726217406467(uint64_t hash) { + return hash % 1726217406467llu; + } + static uint64_t mod2174897647073(uint64_t hash) { + return hash % 2174897647073llu; + } + static uint64_t mod2740199326961(uint64_t hash) { + return hash % 2740199326961llu; + } + static uint64_t mod3452434812973(uint64_t hash) { + return hash % 3452434812973llu; + } + static uint64_t mod4349795294267(uint64_t hash) { + return hash % 4349795294267llu; + } + static uint64_t mod5480398654009(uint64_t hash) { + return hash % 5480398654009llu; + } + static uint64_t mod6904869625999(uint64_t hash) { + return hash % 6904869625999llu; + } + static uint64_t mod8699590588571(uint64_t hash) { + return hash % 8699590588571llu; + } + static uint64_t mod10960797308051(uint64_t hash) { + return hash % 10960797308051llu; + } + static uint64_t mod13809739252051(uint64_t hash) { + return hash % 13809739252051llu; + } + static uint64_t mod17399181177241(uint64_t hash) { + return hash % 17399181177241llu; + } + static uint64_t mod21921594616111(uint64_t hash) { + return hash % 21921594616111llu; + } + static uint64_t mod27619478504183(uint64_t hash) { + return hash % 27619478504183llu; + } + static uint64_t mod34798362354533(uint64_t hash) { + return hash % 34798362354533llu; + } + static uint64_t mod43843189232363(uint64_t hash) { + return hash % 43843189232363llu; + } + static uint64_t mod55238957008387(uint64_t hash) { + return hash % 55238957008387llu; + } + static uint64_t mod69596724709081(uint64_t hash) { + return hash % 69596724709081llu; + } + static uint64_t mod87686378464759(uint64_t hash) { + return hash % 87686378464759llu; + } + static uint64_t mod110477914016779(uint64_t hash) { + return hash % 110477914016779llu; + } + static uint64_t mod139193449418173(uint64_t hash) { + return hash % 139193449418173llu; + } + static uint64_t mod175372756929481(uint64_t hash) { + return hash % 175372756929481llu; + } + static uint64_t mod220955828033581(uint64_t hash) { + return hash % 220955828033581llu; + } + static uint64_t mod278386898836457(uint64_t hash) { + return hash % 278386898836457llu; + } + static uint64_t mod350745513859007(uint64_t hash) { + return hash % 350745513859007llu; + } + static uint64_t mod441911656067171(uint64_t hash) { + return hash % 441911656067171llu; + } + static uint64_t mod556773797672909(uint64_t hash) { + return hash % 556773797672909llu; + } + static uint64_t mod701491027718027(uint64_t hash) { + return hash % 701491027718027llu; + } + static uint64_t mod883823312134381(uint64_t hash) { + return hash % 883823312134381llu; + } + static uint64_t mod1113547595345903(uint64_t hash) { + return hash % 1113547595345903llu; + } + static uint64_t mod1402982055436147(uint64_t hash) { + return hash % 1402982055436147llu; + } + static uint64_t mod1767646624268779(uint64_t hash) { + return hash % 1767646624268779llu; + } + static uint64_t mod2227095190691797(uint64_t hash) { + return hash % 2227095190691797llu; + } + static uint64_t mod2805964110872297(uint64_t hash) { + return hash % 2805964110872297llu; + } + static uint64_t mod3535293248537579(uint64_t hash) { + return hash % 3535293248537579llu; + } + static uint64_t mod4454190381383713(uint64_t hash) { + return hash % 4454190381383713llu; + } + static uint64_t mod5611928221744609(uint64_t hash) { + return hash % 5611928221744609llu; + } + static uint64_t mod7070586497075177(uint64_t hash) { + return hash % 7070586497075177llu; + } + static uint64_t mod8908380762767489(uint64_t hash) { + return hash % 8908380762767489llu; + } + static uint64_t mod11223856443489329(uint64_t hash) { + return hash % 11223856443489329llu; + } + static uint64_t mod14141172994150357(uint64_t hash) { + return hash % 14141172994150357llu; + } + static uint64_t mod17816761525534927(uint64_t hash) { + return hash % 17816761525534927llu; + } + static uint64_t mod22447712886978529(uint64_t hash) { + return hash % 22447712886978529llu; + } + static uint64_t mod28282345988300791(uint64_t hash) { + return hash % 28282345988300791llu; + } + static uint64_t mod35633523051069991(uint64_t hash) { + return hash % 35633523051069991llu; + } + static uint64_t mod44895425773957261(uint64_t hash) { + return hash % 44895425773957261llu; + } + static uint64_t mod56564691976601587(uint64_t hash) { + return hash % 56564691976601587llu; + } + static uint64_t mod71267046102139967(uint64_t hash) { + return hash % 71267046102139967llu; + } + static uint64_t mod89790851547914507(uint64_t hash) { + return hash % 89790851547914507llu; + } + static uint64_t mod113129383953203213(uint64_t hash) { + return hash % 113129383953203213llu; + } + static uint64_t mod142534092204280003(uint64_t hash) { + return hash % 142534092204280003llu; + } + static uint64_t mod179581703095829107(uint64_t hash) { + return hash % 179581703095829107llu; + } + static uint64_t mod226258767906406483(uint64_t hash) { + return hash % 226258767906406483llu; + } + static uint64_t mod285068184408560057(uint64_t hash) { + return hash % 285068184408560057llu; + } + static uint64_t mod359163406191658253(uint64_t hash) { + return hash % 359163406191658253llu; + } + static uint64_t mod452517535812813007(uint64_t hash) { + return hash % 452517535812813007llu; + } + static uint64_t mod570136368817120201(uint64_t hash) { + return hash % 570136368817120201llu; + } + static uint64_t mod718326812383316683(uint64_t hash) { + return hash % 718326812383316683llu; + } + static uint64_t mod905035071625626043(uint64_t hash) { + return hash % 905035071625626043llu; + } + static uint64_t mod1140272737634240411(uint64_t hash) { + return hash % 1140272737634240411llu; + } + static uint64_t mod1436653624766633509(uint64_t hash) { + return hash % 1436653624766633509llu; + } + static uint64_t mod1810070143251252131(uint64_t hash) { + return hash % 1810070143251252131llu; + } + static uint64_t mod2280545475268481167(uint64_t hash) { + return hash % 2280545475268481167llu; + } + static uint64_t mod2873307249533267101(uint64_t hash) { + return hash % 2873307249533267101llu; + } + static uint64_t mod3620140286502504283(uint64_t hash) { + return hash % 3620140286502504283llu; + } + static uint64_t mod4561090950536962147(uint64_t hash) { + return hash % 4561090950536962147llu; + } + static uint64_t mod5746614499066534157(uint64_t hash) { + return hash % 5746614499066534157llu; + } + static uint64_t mod7240280573005008577(uint64_t hash) { + return hash % 7240280573005008577llu; + } + static uint64_t mod9122181901073924329(uint64_t hash) { + return hash % 9122181901073924329llu; + } + static uint64_t mod11493228998133068689(uint64_t hash) { + return hash % 11493228998133068689llu; + } + static uint64_t mod14480561146010017169(uint64_t hash) { + return hash % 14480561146010017169llu; + } + static uint64_t mod18446744073709551557(uint64_t hash) { + return hash % 18446744073709551557llu; + } + + using mod_function = uint64_t (*)(uint64_t); + + mod_function next_size_over(uint64_t& size) const { + // prime numbers generated by the following method: + // 1. start with a prime p = 2 + // 2. go to wolfram alpha and get p = NextPrime(2 * p) + // 3. repeat 2. until you overflow 64 bits + // you now have large gaps which you would hit if somebody called reserve() + // with an unlucky number. + // 4. to fill the gaps for every prime p go to wolfram alpha and get + // ClosestPrime(p * 2^(1/3)) and ClosestPrime(p * 2^(2/3)) and put those in + // the gaps + // 5. get PrevPrime(2^64) and put it at the end + // NOLINTNEXTLINE(*c-arrays*) + static constexpr const uint64_t prime_list[] = { + 2llu, + 3llu, + 5llu, + 7llu, + 11llu, + 13llu, + 17llu, + 23llu, + 29llu, + 37llu, + 47llu, + 59llu, + 73llu, + 97llu, + 127llu, + 151llu, + 197llu, + 251llu, + 313llu, + 397llu, + 499llu, + 631llu, + 797llu, + 1009llu, + 1259llu, + 1597llu, + 2011llu, + 2539llu, + 3203llu, + 4027llu, + 5087llu, + 6421llu, + 8089llu, + 10193llu, + 12853llu, + 16193llu, + 20399llu, + 25717llu, + 32401llu, + 40823llu, + 51437llu, + 64811llu, + 81649llu, + 102877llu, + 129607llu, + 163307llu, + 205759llu, + 259229llu, + 326617llu, + 411527llu, + 518509llu, + 653267llu, + 823117llu, + 1037059llu, + 1306601llu, + 1646237llu, + 2074129llu, + 2613229llu, + 3292489llu, + 4148279llu, + 5226491llu, + 6584983llu, + 8296553llu, + 10453007llu, + 13169977llu, + 16593127llu, + 20906033llu, + 26339969llu, + 33186281llu, + 41812097llu, + 52679969llu, + 66372617llu, + 83624237llu, + 105359939llu, + 132745199llu, + 167248483llu, + 210719881llu, + 265490441llu, + 334496971llu, + 421439783llu, + 530980861llu, + 668993977llu, + 842879579llu, + 1061961721llu, + 1337987929llu, + 1685759167llu, + 2123923447llu, + 2675975881llu, + 3371518343llu, + 4247846927llu, + 5351951779llu, + 6743036717llu, + 8495693897llu, + 10703903591llu, + 13486073473llu, + 16991387857llu, + 21407807219llu, + 26972146961llu, + 33982775741llu, + 42815614441llu, + 53944293929llu, + 67965551447llu, + 85631228929llu, + 107888587883llu, + 135931102921llu, + 171262457903llu, + 215777175787llu, + 271862205833llu, + 342524915839llu, + 431554351609llu, + 543724411781llu, + 685049831731llu, + 863108703229llu, + 1087448823553llu, + 1370099663459llu, + 1726217406467llu, + 2174897647073llu, + 2740199326961llu, + 3452434812973llu, + 4349795294267llu, + 5480398654009llu, + 6904869625999llu, + 8699590588571llu, + 10960797308051llu, + 13809739252051llu, + 17399181177241llu, + 21921594616111llu, + 27619478504183llu, + 34798362354533llu, + 43843189232363llu, + 55238957008387llu, + 69596724709081llu, + 87686378464759llu, + 110477914016779llu, + 139193449418173llu, + 175372756929481llu, + 220955828033581llu, + 278386898836457llu, + 350745513859007llu, + 441911656067171llu, + 556773797672909llu, + 701491027718027llu, + 883823312134381llu, + 1113547595345903llu, + 1402982055436147llu, + 1767646624268779llu, + 2227095190691797llu, + 2805964110872297llu, + 3535293248537579llu, + 4454190381383713llu, + 5611928221744609llu, + 7070586497075177llu, + 8908380762767489llu, + 11223856443489329llu, + 14141172994150357llu, + 17816761525534927llu, + 22447712886978529llu, + 28282345988300791llu, + 35633523051069991llu, + 44895425773957261llu, + 56564691976601587llu, + 71267046102139967llu, + 89790851547914507llu, + 113129383953203213llu, + 142534092204280003llu, + 179581703095829107llu, + 226258767906406483llu, + 285068184408560057llu, + 359163406191658253llu, + 452517535812813007llu, + 570136368817120201llu, + 718326812383316683llu, + 905035071625626043llu, + 1140272737634240411llu, + 1436653624766633509llu, + 1810070143251252131llu, + 2280545475268481167llu, + 2873307249533267101llu, + 3620140286502504283llu, + 4561090950536962147llu, + 5746614499066534157llu, + 7240280573005008577llu, + 9122181901073924329llu, + 11493228998133068689llu, + 14480561146010017169llu, + 18446744073709551557llu}; + // NOLINTNEXTLINE(*c-arrays*) + static constexpr uint64_t (*const mod_functions[])(uint64_t) = { + &mod0, + &mod2, + &mod3, + &mod5, + &mod7, + &mod11, + &mod13, + &mod17, + &mod23, + &mod29, + &mod37, + &mod47, + &mod59, + &mod73, + &mod97, + &mod127, + &mod151, + &mod197, + &mod251, + &mod313, + &mod397, + &mod499, + &mod631, + &mod797, + &mod1009, + &mod1259, + &mod1597, + &mod2011, + &mod2539, + &mod3203, + &mod4027, + &mod5087, + &mod6421, + &mod8089, + &mod10193, + &mod12853, + &mod16193, + &mod20399, + &mod25717, + &mod32401, + &mod40823, + &mod51437, + &mod64811, + &mod81649, + &mod102877, + &mod129607, + &mod163307, + &mod205759, + &mod259229, + &mod326617, + &mod411527, + &mod518509, + &mod653267, + &mod823117, + &mod1037059, + &mod1306601, + &mod1646237, + &mod2074129, + &mod2613229, + &mod3292489, + &mod4148279, + &mod5226491, + &mod6584983, + &mod8296553, + &mod10453007, + &mod13169977, + &mod16593127, + &mod20906033, + &mod26339969, + &mod33186281, + &mod41812097, + &mod52679969, + &mod66372617, + &mod83624237, + &mod105359939, + &mod132745199, + &mod167248483, + &mod210719881, + &mod265490441, + &mod334496971, + &mod421439783, + &mod530980861, + &mod668993977, + &mod842879579, + &mod1061961721, + &mod1337987929, + &mod1685759167, + &mod2123923447, + &mod2675975881, + &mod3371518343, + &mod4247846927, + &mod5351951779, + &mod6743036717, + &mod8495693897, + &mod10703903591, + &mod13486073473, + &mod16991387857, + &mod21407807219, + &mod26972146961, + &mod33982775741, + &mod42815614441, + &mod53944293929, + &mod67965551447, + &mod85631228929, + &mod107888587883, + &mod135931102921, + &mod171262457903, + &mod215777175787, + &mod271862205833, + &mod342524915839, + &mod431554351609, + &mod543724411781, + &mod685049831731, + &mod863108703229, + &mod1087448823553, + &mod1370099663459, + &mod1726217406467, + &mod2174897647073, + &mod2740199326961, + &mod3452434812973, + &mod4349795294267, + &mod5480398654009, + &mod6904869625999, + &mod8699590588571, + &mod10960797308051, + &mod13809739252051, + &mod17399181177241, + &mod21921594616111, + &mod27619478504183, + &mod34798362354533, + &mod43843189232363, + &mod55238957008387, + &mod69596724709081, + &mod87686378464759, + &mod110477914016779, + &mod139193449418173, + &mod175372756929481, + &mod220955828033581, + &mod278386898836457, + &mod350745513859007, + &mod441911656067171, + &mod556773797672909, + &mod701491027718027, + &mod883823312134381, + &mod1113547595345903, + &mod1402982055436147, + &mod1767646624268779, + &mod2227095190691797, + &mod2805964110872297, + &mod3535293248537579, + &mod4454190381383713, + &mod5611928221744609, + &mod7070586497075177, + &mod8908380762767489, + &mod11223856443489329, + &mod14141172994150357, + &mod17816761525534927, + &mod22447712886978529, + &mod28282345988300791, + &mod35633523051069991, + &mod44895425773957261, + &mod56564691976601587, + &mod71267046102139967, + &mod89790851547914507, + &mod113129383953203213, + &mod142534092204280003, + &mod179581703095829107, + &mod226258767906406483, + &mod285068184408560057, + &mod359163406191658253, + &mod452517535812813007, + &mod570136368817120201, + &mod718326812383316683, + &mod905035071625626043, + &mod1140272737634240411, + &mod1436653624766633509, + &mod1810070143251252131, + &mod2280545475268481167, + &mod2873307249533267101, + &mod3620140286502504283, + &mod4561090950536962147, + &mod5746614499066534157, + &mod7240280573005008577, + &mod9122181901073924329, + &mod11493228998133068689, + &mod14480561146010017169, + &mod18446744073709551557}; + const uint64_t* found = std::lower_bound( + std::begin(prime_list), std::end(prime_list) - 1, size); + size = *found; + return mod_functions[1 + found - prime_list]; + } + void commit(mod_function new_mod_function) { + current_mod_function = new_mod_function; + } + void reset() { + current_mod_function = &mod0; + } + + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) + const { + return current_mod_function(hash); + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index > num_slots_minus_one ? current_mod_function(index) : index; + } + + private: + mod_function current_mod_function = &mod0; +}; + +struct power_of_two_hash_policy { + uint64_t index_for_hash(uint64_t hash, uint64_t num_slots_minus_one) const { + return hash & num_slots_minus_one; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index_for_hash(index, num_slots_minus_one); + } + int8_t next_size_over(uint64_t& size) const { + size = detailv3::next_power_of_two(size); + return 0; + } + void commit(int8_t /*unused*/) {} + void reset() {} +}; + +struct fibonacci_hash_policy { + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) + const { + return (11400714819323198485ull * hash) >> shift; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index & num_slots_minus_one; + } + + int8_t next_size_over(uint64_t& size) const { + size = std::max(uint64_t(2), detailv3::next_power_of_two(size)); + return static_cast(64 - detailv3::log2(size)); + } + void commit(int8_t shift_) { + shift = shift_; + } + void reset() { + shift = 63; + } + + private: + int8_t shift = 63; +}; + +template < + typename K, + typename V, + typename H = std::hash, + typename E = std::equal_to, + typename A = std::allocator>> +class flat_hash_map + : public detailv3::sherwood_v3_table< + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>> { + using Table = detailv3::sherwood_v3_table< + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>>; + + public: + using key_type = K; + using mapped_type = V; + + using Table::Table; + flat_hash_map() = default; + + inline V& operator[](const K& key) { + return emplace(key, convertible_to_value()).first->second; + } + inline V& operator[](K&& key) { + return emplace(std::move(key), convertible_to_value()).first->second; + } + V& at(const K& key) { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + const V& at(const K& key) const { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + + using Table::emplace; + std::pair emplace() { + return emplace(key_type(), convertible_to_value()); + } + template + std::pair insert_or_assign( + const key_type& key, + M&& m) { + auto emplace_result = emplace(key, std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + std::pair insert_or_assign( + key_type&& key, + M&& m) { + auto emplace_result = emplace(std::move(key), std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + typename Table::iterator insert_or_assign( + typename Table::const_iterator /*unused*/, + const key_type& key, + M&& m) { + return insert_or_assign(key, std::forward(m)).first; + } + template + typename Table::iterator insert_or_assign( + typename Table::const_iterator /*unused*/, + key_type&& key, + M&& m) { + return insert_or_assign(std::move(key), std::forward(m)).first; + } + + friend bool operator==(const flat_hash_map& lhs, const flat_hash_map& rhs) { + if (lhs.size() != rhs.size()) + return false; + for (const typename Table::value_type& value : lhs) { + auto found = rhs.find(value.first); + if (found == rhs.end() || value.second != found->second) + return false; + } + return true; + } + friend bool operator!=(const flat_hash_map& lhs, const flat_hash_map& rhs) { + return !(lhs == rhs); + } + + private: + struct convertible_to_value { + operator V() const { + return V(); + } + }; +}; + +template < + typename T, + typename H = std::hash, + typename E = std::equal_to, + typename A = std::allocator> +class flat_hash_set + : public detailv3::sherwood_v3_table< + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>> { + using Table = detailv3::sherwood_v3_table< + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>; + + public: + using key_type = T; + + using Table::Table; + flat_hash_set() = default; + + template + std::pair emplace(Args&&... args) { + return Table::emplace(T(std::forward(args)...)); + } + std::pair emplace(const key_type& arg) { + return Table::emplace(arg); + } + std::pair emplace(key_type& arg) { + return Table::emplace(arg); + } + std::pair emplace(const key_type&& arg) { + return Table::emplace(std::move(arg)); + } + std::pair emplace(key_type&& arg) { + return Table::emplace(std::move(arg)); + } + + friend bool operator==(const flat_hash_set& lhs, const flat_hash_set& rhs) { + if (lhs.size() != rhs.size()) + return false; + for (const T& value : lhs) { + if (rhs.find(value) == rhs.end()) + return false; + } + return true; + } + friend bool operator!=(const flat_hash_set& lhs, const flat_hash_set& rhs) { + return !(lhs == rhs); + } +}; + +template +struct power_of_two_std_hash : std::hash { + typedef ska::power_of_two_hash_policy hash_policy; +}; + +} // end namespace ska + +C10_CLANG_DIAGNOSTIC_POP() + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/floating_point_utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/floating_point_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b83f9c931e4cf13b648336b4331a6f33b0a6fda2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/floating_point_utils.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/generic_math.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/generic_math.h new file mode 100644 index 0000000000000000000000000000000000000000..969e095ef59a8ad07bf80089a054d14e84c682d6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/generic_math.h @@ -0,0 +1,113 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#if defined(__CUDA_ARCH__) +#include +#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign +#elif defined(__HIPCC__) +#include +#define C10_COMPAT_COPYSIGN c10::hip::compat::copysign +#else +#include +#define C10_COMPAT_COPYSIGN c10::copysign +#endif + +// The functions in this file should be header-only as it is used under +// ABI-compatibility mode. + +namespace c10 { + +// NOTE: [Floor Division in Python] +// Python's __floordiv__ operator is more complicated than just floor(a / b). +// It aims to maintain the property: a == (a // b) * b + remainder(a, b) +// which can otherwise fail due to rounding errors in the remainder. +// So, instead it is calculated as: a // b = (a - remainder(a, b)) / b +// With some additional fix-ups added to the result. +// +// For reference, see CPython's implementation: +// https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + +template +inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b) + __ubsan_ignore_float_divide_by_zero__ { + if (C10_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return a / b; + } + + auto mod = std::fmod(a, b); + auto div = (a - mod) / b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = C10_COMPAT_COPYSIGN(scalar_t(0), a / b); + } + return floordiv; +} + +template +inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) { + if (C10_UNLIKELY( + std::is_signed::value && + a == std::numeric_limits::min() && b == scalar_t(-1))) { + return a; + } + + if (c10::signs_differ(a, b)) { + // Subtracts one from the results of truncation division if the + // divisor and dividend have different sign(bit)s and the remainder of + // the division is nonzero + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + } + return a / b; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) + __ubsan_ignore_float_divide_by_zero__ { + if (C10_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return std::fmod(a, b); + } + + auto mod = std::fmod(a, b); + if (mod == 0) { + mod = C10_COMPAT_COPYSIGN(scalar_t(0), b); + } else if ((b < 0) != (mod < 0)) { + mod += b; + } + return mod; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) { + auto mod = a % b; + if (mod != 0 && (b < 0) != (mod < 0)) { + mod += b; + } + return mod; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/hash.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/hash.h new file mode 100644 index 0000000000000000000000000000000000000000..c3fff128439efb6d4ddf143493fd9a3d46b04435 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/hash.h @@ -0,0 +1,384 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +// NOTE: hash_combine and SHA1 hashing is based on implementation from Boost +// +// Boost Software License - Version 1.0 - August 17th, 2003 +// +// Permission is hereby granted, free of charge, to any person or organization +// obtaining a copy of the software and accompanying documentation covered by +// this license (the "Software") to use, reproduce, display, distribute, +// execute, and transmit the Software, and to prepare derivative works of the +// Software, and to permit third-parties to whom the Software is furnished to +// do so, all subject to the following: +// +// The copyright notices in the Software and this entire statement, including +// the above license grant, this restriction and the following disclaimer, +// must be included in all copies of the Software, in whole or in part, and +// all derivative works of the Software, unless such copies or derivative +// works are solely in the form of machine-executable object code generated by +// a source language processor. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +inline size_t hash_combine(size_t seed, size_t value) { + return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u)); +} + +// Creates the SHA1 hash of a string. A 160-bit hash. +// Based on the implementation in Boost (see notice above). +// Note that SHA1 hashes are no longer considered cryptographically +// secure, but are the standard hash for generating unique ids. +// Usage: +// // Let 'code' be a std::string +// c10::sha1 sha1_hash{code}; +// const auto hash_code = sha1_hash.str(); +// TODO: Compare vs OpenSSL and/or CryptoPP implementations +struct sha1 { + typedef unsigned int(digest_type)[5]; + + sha1(const std::string& s = "") { + if (!s.empty()) { + reset(); + process_bytes(s.c_str(), s.size()); + } + } + + void reset() { + h_[0] = 0x67452301; + h_[1] = 0xEFCDAB89; + h_[2] = 0x98BADCFE; + h_[3] = 0x10325476; + h_[4] = 0xC3D2E1F0; + + block_byte_index_ = 0; + bit_count_low = 0; + bit_count_high = 0; + } + + std::string str() { + unsigned int digest[5]; + get_digest(digest); + + std::ostringstream buf; + for (unsigned int i : digest) { + buf << std::hex << std::setfill('0') << std::setw(8) << i; + } + + return buf.str(); + } + + private: + unsigned int left_rotate(unsigned int x, std::size_t n) { + return (x << n) ^ (x >> (32 - n)); + } + + void process_block_impl() { + unsigned int w[80]; + + for (std::size_t i = 0; i < 16; ++i) { + w[i] = (block_[i * 4 + 0] << 24); + w[i] |= (block_[i * 4 + 1] << 16); + w[i] |= (block_[i * 4 + 2] << 8); + w[i] |= (block_[i * 4 + 3]); + } + + for (std::size_t i = 16; i < 80; ++i) { + w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1); + } + + unsigned int a = h_[0]; + unsigned int b = h_[1]; + unsigned int c = h_[2]; + unsigned int d = h_[3]; + unsigned int e = h_[4]; + + for (std::size_t i = 0; i < 80; ++i) { + unsigned int f = 0; + unsigned int k = 0; + + if (i < 20) { + f = (b & c) | (~b & d); + k = 0x5A827999; + } else if (i < 40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i < 60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + + unsigned temp = left_rotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = left_rotate(b, 30); + b = a; + a = temp; + } + + h_[0] += a; + h_[1] += b; + h_[2] += c; + h_[3] += d; + h_[4] += e; + } + + void process_byte_impl(unsigned char byte) { + block_[block_byte_index_++] = byte; + + if (block_byte_index_ == 64) { + block_byte_index_ = 0; + process_block_impl(); + } + } + + void process_byte(unsigned char byte) { + process_byte_impl(byte); + + // size_t max value = 0xFFFFFFFF + // if (bit_count_low + 8 >= 0x100000000) { // would overflow + // if (bit_count_low >= 0x100000000-8) { + if (bit_count_low < 0xFFFFFFF8) { + bit_count_low += 8; + } else { + bit_count_low = 0; + + if (bit_count_high <= 0xFFFFFFFE) { + ++bit_count_high; + } else { + TORCH_CHECK(false, "sha1 too many bytes"); + } + } + } + + void process_block(void const* bytes_begin, void const* bytes_end) { + unsigned char const* begin = static_cast(bytes_begin); + unsigned char const* end = static_cast(bytes_end); + for (; begin != end; ++begin) { + process_byte(*begin); + } + } + + void process_bytes(void const* buffer, std::size_t byte_count) { + unsigned char const* b = static_cast(buffer); + process_block(b, b + byte_count); + } + + void get_digest(digest_type& digest) { + // append the bit '1' to the message + process_byte_impl(0x80); + + // append k bits '0', where k is the minimum number >= 0 + // such that the resulting message length is congruent to 56 (mod 64) + // check if there is enough space for padding and bit_count + if (block_byte_index_ > 56) { + // finish this block + while (block_byte_index_ != 0) { + process_byte_impl(0); + } + + // one more block + while (block_byte_index_ < 56) { + process_byte_impl(0); + } + } else { + while (block_byte_index_ < 56) { + process_byte_impl(0); + } + } + + // append length of message (before pre-processing) + // as a 64-bit big-endian integer + process_byte_impl( + static_cast((bit_count_high >> 24) & 0xFF)); + process_byte_impl( + static_cast((bit_count_high >> 16) & 0xFF)); + process_byte_impl(static_cast((bit_count_high >> 8) & 0xFF)); + process_byte_impl(static_cast((bit_count_high) & 0xFF)); + process_byte_impl(static_cast((bit_count_low >> 24) & 0xFF)); + process_byte_impl(static_cast((bit_count_low >> 16) & 0xFF)); + process_byte_impl(static_cast((bit_count_low >> 8) & 0xFF)); + process_byte_impl(static_cast((bit_count_low) & 0xFF)); + + // get final digest + digest[0] = h_[0]; + digest[1] = h_[1]; + digest[2] = h_[2]; + digest[3] = h_[3]; + digest[4] = h_[4]; + } + + unsigned int h_[5]{}; + unsigned char block_[64]{}; + std::size_t block_byte_index_{}; + std::size_t bit_count_low{}; + std::size_t bit_count_high{}; +}; + +constexpr uint64_t twang_mix64(uint64_t key) noexcept { + key = (~key) + (key << 21); // key *= (1 << 21) - 1; key -= 1; + key = key ^ (key >> 24); + key = key + (key << 3) + (key << 8); // key *= 1 + (1 << 3) + (1 << 8) + key = key ^ (key >> 14); + key = key + (key << 2) + (key << 4); // key *= 1 + (1 << 2) + (1 << 4) + key = key ^ (key >> 28); + key = key + (key << 31); // key *= 1 + (1 << 31) + return key; +} + +//////////////////////////////////////////////////////////////////////////////// +// c10::hash implementation +//////////////////////////////////////////////////////////////////////////////// + +namespace _hash_detail { + +// Use template argument deduction to shorten calls to c10::hash +template +size_t simple_get_hash(const T& o); + +template +using type_if_not_enum = std::enable_if_t, V>; + +// Use SFINAE to dispatch to std::hash if possible, cast enum types to int +// automatically, and fall back to T::hash otherwise. NOTE: C++14 added support +// for hashing enum types to the standard, and some compilers implement it even +// when C++14 flags aren't specified. This is why we have to disable this +// overload if T is an enum type (and use the one below in this case). +template +auto dispatch_hash(const T& o) + -> decltype(std::hash()(o), type_if_not_enum()) { + return std::hash()(o); +} + +template +std::enable_if_t, size_t> dispatch_hash(const T& o) { + using R = std::underlying_type_t; + return std::hash()(static_cast(o)); +} + +template +auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) { + return T::hash(o); +} + +} // namespace _hash_detail + +// Hasher struct +template +struct hash { + size_t operator()(const T& o) const { + return _hash_detail::dispatch_hash(o); + } +}; + +// Specialization for std::tuple +template +struct hash> { + template + struct tuple_hash { + size_t operator()(const std::tuple& t) const { + return hash_combine( + _hash_detail::simple_get_hash(std::get(t)), + tuple_hash()(t)); + } + }; + + template + struct tuple_hash<0, Ts...> { + size_t operator()(const std::tuple& t) const { + return _hash_detail::simple_get_hash(std::get<0>(t)); + } + }; + + size_t operator()(const std::tuple& t) const { + return tuple_hash()(t); + } +}; + +template +struct hash> { + size_t operator()(const std::pair& pair) const { + std::tuple tuple = std::make_tuple(pair.first, pair.second); + return _hash_detail::simple_get_hash(tuple); + } +}; + +template +struct hash> { + size_t operator()(c10::ArrayRef v) const { + size_t seed = 0; + for (const auto& elem : v) { + seed = hash_combine(seed, _hash_detail::simple_get_hash(elem)); + } + return seed; + } +}; + +// Specialization for std::vector +template +struct hash> { + size_t operator()(const std::vector& v) const { + return hash>()(v); + } +}; + +namespace _hash_detail { + +template +size_t simple_get_hash(const T& o) { + return c10::hash()(o); +} + +} // namespace _hash_detail + +// Use this function to actually hash multiple things in one line. +// Dispatches to c10::hash, so it can hash containers. +// Example: +// +// static size_t hash(const MyStruct& s) { +// return get_hash(s.member1, s.member2, s.member3); +// } +template +size_t get_hash(const Types&... args) { + return c10::hash()(std::tie(args...)); +} + +// Specialization for c10::complex +template +struct hash> { + size_t operator()(const c10::complex& c) const { + return get_hash(c.real(), c.imag()); + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/int128.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/int128.h new file mode 100644 index 0000000000000000000000000000000000000000..73687a69d1bbc0bfe4a0d449cbf43f10437e29bd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/int128.h @@ -0,0 +1,403 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// This file is based on the uint128 implementation of protobuf at +// https://github.com/protocolbuffers/protobuf/blob/1e88936fce10cf773cb72b44c6a7f48b38c7578b/src/google/protobuf/stubs/int128.h +// +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include + +namespace c10 { + +struct uint128_pod; + +// TODO(xiaofeng): Define GOOGLE_PROTOBUF_HAS_CONSTEXPR when constexpr is +// available. +#ifdef GOOGLE_PROTOBUF_HAS_CONSTEXPR +#define UINT128_CONSTEXPR constexpr +#else +#define UINT128_CONSTEXPR +#endif + +class uint128; +inline uint128& operator<<=(uint128& self, int amount); + +// An unsigned 128-bit integer type. Thread-compatible. +class C10_API uint128 { + public: + UINT128_CONSTEXPR uint128(); // Sets to 0, but don't trust on this behavior. + UINT128_CONSTEXPR uint128(uint64_t top, uint64_t bottom); +#ifndef SWIG + UINT128_CONSTEXPR uint128(int bottom); + UINT128_CONSTEXPR uint128(uint32_t bottom); // Top 96 bits = 0 +#endif + UINT128_CONSTEXPR uint128(uint64_t bottom); // hi_ = 0 + UINT128_CONSTEXPR uint128(const uint128_pod& val); + + // Trivial copy constructor, assignment operator and destructor. + + void Initialize(uint64_t top, uint64_t bottom); + + // Arithmetic operators. + uint128& operator+=(const uint128& b); + uint128& operator-=(const uint128& b); + uint128& operator*=(const uint128& b); + // Long division/modulo for uint128. + uint128& operator/=(const uint128& b); + uint128& operator%=(const uint128& b); + uint128 operator++(int); + uint128 operator--(int); + // Make msvc happy with using operator<<= from DivModImpl + // which is a static function, and linker complained about missing + // static version of this overload + friend uint128& operator<<=(uint128& /*self*/, int /*amount*/); + uint128& operator>>=(int /*amount*/); + uint128& operator&=(const uint128& b); + uint128& operator|=(const uint128& b); + uint128& operator^=(const uint128& b); + uint128& operator++(); + uint128& operator--(); + + friend uint64_t Uint128Low64(const uint128& v); + friend uint64_t Uint128High64(const uint128& v); + + // We add "std::" to avoid including all of port.h. + C10_API friend std::ostream& operator<<(std::ostream& o, const uint128& b); + + private: + static void DivModImpl( + uint128 dividend, + uint128 divisor, + uint128* quotient_ret, + uint128* remainder_ret); + + // Little-endian memory order optimizations can benefit from + // having lo_ first, hi_ last. + // See util/endian/endian.h and Load128/Store128 for storing a uint128. + uint64_t lo_; + uint64_t hi_; + + // Not implemented, just declared for catching automatic type conversions. + uint128(uint8_t); + uint128(uint16_t); + uint128(float v); + uint128(double v); +}; + +// This is a POD form of uint128 which can be used for static variables which +// need to be operated on as uint128. +struct uint128_pod { + // Note: The ordering of fields is different than 'class uint128' but the + // same as its 2-arg constructor. This enables more obvious initialization + // of static instances, which is the primary reason for this struct in the + // first place. This does not seem to defeat any optimizations wrt + // operations involving this struct. + uint64_t hi; + uint64_t lo; +}; + +C10_API extern const uint128_pod kuint128max; + +// allow uint128 to be logged +C10_API extern std::ostream& operator<<(std::ostream& o, const uint128& b); + +// Methods to access low and high pieces of 128-bit value. +// Defined externally from uint128 to facilitate conversion +// to native 128-bit types when compilers support them. +inline uint64_t Uint128Low64(const uint128& v) { + return v.lo_; +} +inline uint64_t Uint128High64(const uint128& v) { + return v.hi_; +} + +// TODO: perhaps it would be nice to have int128, a signed 128-bit type? + +// -------------------------------------------------------------------------- +// Implementation details follow +// -------------------------------------------------------------------------- +inline bool operator==(const uint128& lhs, const uint128& rhs) { + return ( + Uint128Low64(lhs) == Uint128Low64(rhs) && + Uint128High64(lhs) == Uint128High64(rhs)); +} +inline bool operator!=(const uint128& lhs, const uint128& rhs) { + return !(lhs == rhs); +} + +inline UINT128_CONSTEXPR uint128::uint128() : lo_(0), hi_(0) {} +inline UINT128_CONSTEXPR uint128::uint128(uint64_t top, uint64_t bottom) + : lo_(bottom), hi_(top) {} +inline UINT128_CONSTEXPR uint128::uint128(const uint128_pod& v) + : lo_(v.lo), hi_(v.hi) {} +inline UINT128_CONSTEXPR uint128::uint128(uint64_t bottom) + : lo_(bottom), hi_(0) {} +#ifndef SWIG +inline UINT128_CONSTEXPR uint128::uint128(uint32_t bottom) + : lo_(bottom), hi_(0) {} +inline UINT128_CONSTEXPR uint128::uint128(int bottom) + : lo_(bottom), hi_(static_cast((bottom < 0) ? -1 : 0)) {} +#endif + +#undef UINT128_CONSTEXPR + +inline void uint128::Initialize(uint64_t top, uint64_t bottom) { + hi_ = top; + lo_ = bottom; +} + +// Comparison operators. + +#define CMP128(op) \ + inline bool operator op(const uint128& lhs, const uint128& rhs) { \ + return (Uint128High64(lhs) == Uint128High64(rhs)) \ + ? (Uint128Low64(lhs) op Uint128Low64(rhs)) \ + : (Uint128High64(lhs) op Uint128High64(rhs)); \ + } + +CMP128(<) +CMP128(>) +CMP128(>=) +CMP128(<=) + +#undef CMP128 + +// Unary operators + +inline uint128 operator-(const uint128& val) { + const uint64_t hi_flip = ~Uint128High64(val); + const uint64_t lo_flip = ~Uint128Low64(val); + const uint64_t lo_add = lo_flip + 1; + if (lo_add < lo_flip) { + return uint128(hi_flip + 1, lo_add); + } + return uint128(hi_flip, lo_add); +} + +inline bool operator!(const uint128& val) { + return !Uint128High64(val) && !Uint128Low64(val); +} + +// Logical operators. + +inline uint128 operator~(const uint128& val) { + return uint128(~Uint128High64(val), ~Uint128Low64(val)); +} + +#define LOGIC128(op) \ + inline uint128 operator op(const uint128& lhs, const uint128& rhs) { \ + return uint128( \ + Uint128High64(lhs) op Uint128High64(rhs), \ + Uint128Low64(lhs) op Uint128Low64(rhs)); \ + } + +LOGIC128(|) +LOGIC128(&) +LOGIC128(^) + +#undef LOGIC128 + +#define LOGICASSIGN128(op) \ + inline uint128& uint128::operator op(const uint128 & other) { \ + hi_ op other.hi_; \ + lo_ op other.lo_; \ + return *this; \ + } + +LOGICASSIGN128(|=) +LOGICASSIGN128(&=) +LOGICASSIGN128(^=) + +#undef LOGICASSIGN128 + +// Shift operators. + +inline uint128 operator<<(const uint128& val, int amount) { + // uint64_t shifts of >= 64 are undefined, so we will need some + // special-casing. + if (amount < 64) { + if (amount == 0) { + return val; + } + uint64_t new_hi = + (Uint128High64(val) << amount) | (Uint128Low64(val) >> (64 - amount)); + uint64_t new_lo = Uint128Low64(val) << amount; + return uint128(new_hi, new_lo); + } else if (amount < 128) { + return uint128(Uint128Low64(val) << (amount - 64), 0); + } else { + return uint128(0, 0); + } +} + +inline uint128 operator>>(const uint128& val, int amount) { + // uint64_t shifts of >= 64 are undefined, so we will need some + // special-casing. + if (amount < 64) { + if (amount == 0) { + return val; + } + uint64_t new_hi = Uint128High64(val) >> amount; + uint64_t new_lo = + (Uint128Low64(val) >> amount) | (Uint128High64(val) << (64 - amount)); + return uint128(new_hi, new_lo); + } else if (amount < 128) { + return uint128(0, Uint128High64(val) >> (amount - 64)); + } else { + return uint128(0, 0); + } +} + +inline uint128& operator<<=(uint128& self, int amount) { + // uint64_t shifts of >= 64 are undefined, so we will need some + // special-casing. + if (amount < 64) { + if (amount != 0) { + self.hi_ = (self.hi_ << amount) | (self.lo_ >> (64 - amount)); + self.lo_ = self.lo_ << amount; + } + } else if (amount < 128) { + self.hi_ = self.lo_ << (amount - 64); + self.lo_ = 0; + } else { + self.hi_ = 0; + self.lo_ = 0; + } + return self; +} + +inline uint128& uint128::operator>>=(int amount) { + // uint64_t shifts of >= 64 are undefined, so we will need some + // special-casing. + if (amount < 64) { + if (amount != 0) { + lo_ = (lo_ >> amount) | (hi_ << (64 - amount)); + hi_ = hi_ >> amount; + } + } else if (amount < 128) { + lo_ = hi_ >> (amount - 64); + hi_ = 0; + } else { + lo_ = 0; + hi_ = 0; + } + return *this; +} + +inline uint128 operator+(const uint128& lhs, const uint128& rhs) { + return uint128(lhs) += rhs; +} + +inline uint128 operator-(const uint128& lhs, const uint128& rhs) { + return uint128(lhs) -= rhs; +} + +inline uint128 operator*(const uint128& lhs, const uint128& rhs) { + return uint128(lhs) *= rhs; +} + +inline uint128 operator/(const uint128& lhs, const uint128& rhs) { + return uint128(lhs) /= rhs; +} + +inline uint128 operator%(const uint128& lhs, const uint128& rhs) { + return uint128(lhs) %= rhs; +} + +inline uint128& uint128::operator+=(const uint128& b) { + hi_ += b.hi_; + uint64_t lolo = lo_ + b.lo_; + if (lolo < lo_) + ++hi_; + lo_ = lolo; + return *this; +} + +inline uint128& uint128::operator-=(const uint128& b) { + hi_ -= b.hi_; + if (b.lo_ > lo_) + --hi_; + lo_ -= b.lo_; + return *this; +} + +inline uint128& uint128::operator*=(const uint128& b) { + uint64_t a96 = hi_ >> 32; + uint64_t a64 = hi_ & 0xffffffffu; + uint64_t a32 = lo_ >> 32; + uint64_t a00 = lo_ & 0xffffffffu; + uint64_t b96 = b.hi_ >> 32; + uint64_t b64 = b.hi_ & 0xffffffffu; + uint64_t b32 = b.lo_ >> 32; + uint64_t b00 = b.lo_ & 0xffffffffu; + // multiply [a96 .. a00] x [b96 .. b00] + // terms higher than c96 disappear off the high side + // terms c96 and c64 are safe to ignore carry bit + uint64_t c96 = a96 * b00 + a64 * b32 + a32 * b64 + a00 * b96; + uint64_t c64 = a64 * b00 + a32 * b32 + a00 * b64; + this->hi_ = (c96 << 32) + c64; + this->lo_ = 0; + // add terms after this one at a time to capture carry + *this += uint128(a32 * b00) << 32; + *this += uint128(a00 * b32) << 32; + *this += a00 * b00; + return *this; +} + +inline uint128 uint128::operator++(int) { + uint128 tmp(*this); + *this += 1; + return tmp; +} + +inline uint128 uint128::operator--(int) { + uint128 tmp(*this); + *this -= 1; + return tmp; +} + +inline uint128& uint128::operator++() { + *this += 1; + return *this; +} + +inline uint128& uint128::operator--() { + *this -= 1; + return *this; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/intrusive_ptr.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/intrusive_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..148a9bf4a20002de4396c9e0a26ea695b8ed1c98 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/intrusive_ptr.h @@ -0,0 +1,1278 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace pybind11 { +template +class class_; +} + +namespace torch::utils { +class PyObjectPreservation; +} + +namespace c10 { +class intrusive_ptr_target; +namespace raw { +namespace weak_intrusive_ptr { +inline void incref(intrusive_ptr_target* self); +} +namespace intrusive_ptr { +inline void incref(intrusive_ptr_target* self); +} + +// constructor tag used by intrusive_ptr constructors +struct DontIncreaseRefcount {}; +} // namespace raw + +namespace detail { +constexpr uint64_t kImpracticallyHugeReferenceCount = 0x0FFFFFFF; +constexpr uint64_t kImpracticallyHugeWeakReferenceCount = + (kImpracticallyHugeReferenceCount << 32); +constexpr uint64_t kReferenceCountOne = 1; +constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32); +constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne); +// Indicates whether the object has a PyObject wrapper. +constexpr uint64_t kHasPyObject = (uint64_t(1) << 63); + +template +struct intrusive_target_default_null_type final { + static constexpr TTarget* singleton() noexcept { + return nullptr; + } +}; + +template +TTarget* assign_ptr_(TTarget* rhs) { + if (FromNullType::singleton() == rhs) { + return ToNullType::singleton(); + } else { + return rhs; + } +} + +inline uint32_t refcount(uint64_t combined_refcount) { + return static_cast(combined_refcount); +} + +inline uint32_t weakcount(uint64_t combined_refcount) { + return static_cast((combined_refcount & ~kHasPyObject) >> 32); +} + +inline bool has_pyobject(uint64_t combined_refcount) { + return (combined_refcount & kHasPyObject) != 0; +} + +inline bool is_uniquely_owned(uint64_t combined_refcount) { + return (combined_refcount & ~detail::kHasPyObject) == detail::kUniqueRef; +} + +// The only requirement for refcount increment is that it happens-before +// decrement, so no additional memory ordering is needed. +inline uint64_t atomic_combined_refcount_increment( + std::atomic& combined_refcount, + uint64_t inc) { + return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc; +} + +inline uint32_t atomic_weakcount_increment( + std::atomic& combined_refcount) { + return detail::weakcount(atomic_combined_refcount_increment( + combined_refcount, kWeakReferenceCountOne)); +} + +// The requirement is that all modifications to the managed object happen-before +// invocation of the managed object destructor, and that allocation of the +// managed object storage happens-before deallocation of the storage. +// +// To get this ordering, all non-final decrements must synchronize-with the +// final decrement. So all non-final decrements have to store-release while the +// final decrement has to load-acquire, either directly or with the help of +// fences. But it's easiest just to have all decrements be acq-rel. And it turns +// out, on modern architectures and chips, it's also fastest. +inline uint64_t atomic_combined_refcount_decrement( + std::atomic& combined_refcount, + uint64_t dec) { + return combined_refcount.fetch_sub(dec, std::memory_order_acq_rel) - dec; +} + +inline uint32_t atomic_weakcount_decrement( + std::atomic& combined_refcount) { + return detail::weakcount(atomic_combined_refcount_decrement( + combined_refcount, kWeakReferenceCountOne)); +} + +template +struct TargetTraits { + static constexpr bool can_have_pyobject = false; +}; + +} // namespace detail + +/** + * intrusive_ptr is an alternative to shared_ptr that has better + * performance because it does the refcounting intrusively + * (i.e. in a member of the object itself). + * Your class T needs to inherit from intrusive_ptr_target to allow it to be + * used in an intrusive_ptr. Your class's constructor should not allow + *`this` to escape to other threads or create an intrusive_ptr from `this`. + */ + +// Note [Stack allocated intrusive_ptr_target safety] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// A well known problem with std::enable_shared_from_this is that it +// allows you to create a std::shared_ptr from a stack allocated object, +// which is totally bogus because the object will die once you return +// from the stack. In intrusive_ptr, we can detect that this has occurred, +// because we set the refcount/weakcount of objects which inherit from +// intrusive_ptr_target to zero, *unless* we can prove that the object +// was dynamically allocated (e.g., via make_intrusive). +// +// Thus, whenever you transmute a T* into a intrusive_ptr, we check +// and make sure that the refcount isn't zero (or, a more subtle +// test for weak_intrusive_ptr, for which the refcount may validly +// be zero, but the weak refcount better not be zero), because that +// tells us if the object was allocated by us. If it wasn't, no +// intrusive_ptr for you! + +// NOLINTNEXTLINE(cppcoreguidelines-virtual-class-destructor) +class C10_API intrusive_ptr_target { + // Note [Weak references for intrusive refcounting] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Here's the scheme: + // + // - refcount == number of strong references to the object + // weakcount == number of weak references to the object, + // plus one more if refcount > 0 + // An invariant: refcount > 0 => weakcount > 0 + // + // - c10::StorageImpl stays live as long as there are any strong + // or weak pointers to it (weakcount > 0, since strong + // references count as a +1 to weakcount) + // + // - finalizers are called and data_ptr is deallocated when refcount == 0 + // + // - Once refcount == 0, it can never again be > 0 (the transition + // from > 0 to == 0 is monotonic) + // + // - When you access c10::StorageImpl via a weak pointer, you must + // atomically increment the use count, if it is greater than 0. + // If it is not, you must report that the storage is dead. + // + //.We use a single combined count for refcount and weakcount so that + // we can atomically operate on both at the same time for performance + // and defined behaviors. + // + // Note [PyObject preservation for Tensor and Storages] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // intrusive_ptr has special support for preserving PyObject wrappers + // for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of + // the combined_refcount_ is used to indicate whether the object has a + // PyObject wrapper. + // + // - The PyObject, if it exists, holds a strong reference to the + // intrusive_ptr_target. + // + // - When the refcount goes from 1 to 2, we incref the PyObject. + // + // - When the refcount goes from 2 to 1, we decref the PyObject. + // + // In other words, the intrusive_ptr keeps the PyObject alive as long as there + // are other C++ references to the intrusive_ptr_target. + + mutable std::atomic combined_refcount_; + static_assert(sizeof(std::atomic) == 8); + static_assert(alignof(std::atomic) == 8); + static_assert(std::atomic::is_always_lock_free); + + template + friend class intrusive_ptr; + friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self); + + template + friend class weak_intrusive_ptr; + friend inline void raw::weak_intrusive_ptr::incref( + intrusive_ptr_target* self); + + template + friend struct ExclusivelyOwnedTensorTraits; + + friend class torch::utils::PyObjectPreservation; + + protected: + // protected destructor. We never want to destruct intrusive_ptr_target* + // directly. + virtual ~intrusive_ptr_target() { +// Disable -Wterminate and -Wexceptions so we're allowed to use assertions +// (i.e. throw exceptions) in a destructor. +// We also have to disable -Wunknown-warning-option and -Wpragmas, because +// some other compilers don't know about -Wterminate or -Wexceptions and +// will show a warning about unknown warning options otherwise. +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning( \ + disable : 4297) // function assumed not to throw an exception but does +#else +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpragmas" +#pragma GCC diagnostic ignored "-Wunknown-warning-option" +#pragma GCC diagnostic ignored "-Wterminate" +#pragma GCC diagnostic ignored "-Wexceptions" +#endif + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + // Second condition is there to accommodate + // unsafe_adapt_non_heap_allocated: since we are doing our own + // deallocation in that case, it is correct for each + // expected_decref to have happened (some user code tried to + // decref and thus free the object, but it didn't happen right + // away) or not (no user code tried to free the object, and + // now it's getting destroyed through whatever mechanism the + // caller of unsafe_adapt_non_heap_allocated wanted to + // use). We choose our reference count such that the count + // will not dip below kImpracticallyHugeReferenceCount regardless. + refcount() == 0 || + refcount() >= detail::kImpracticallyHugeReferenceCount, + "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ", + refcount()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + // See ~intrusive_ptr for optimization that will frequently result in 1 + // at destruction time. + weakcount() == 1 || weakcount() == 0 || + weakcount() == detail::kImpracticallyHugeReferenceCount - 1 || + weakcount() == detail::kImpracticallyHugeReferenceCount, + "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it"); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#else +#pragma GCC diagnostic pop +#endif + } + + constexpr intrusive_ptr_target() noexcept : combined_refcount_(0) {} + + // intrusive_ptr_target supports copy and move: but refcount and weakcount + // don't participate (since they are intrinsic properties of the memory + // location) + intrusive_ptr_target(intrusive_ptr_target&& /*other*/) noexcept + : intrusive_ptr_target() {} + + intrusive_ptr_target& operator=(intrusive_ptr_target&& /*other*/) noexcept { + return *this; + } + + intrusive_ptr_target(const intrusive_ptr_target& /*other*/) noexcept + : intrusive_ptr_target() {} + + intrusive_ptr_target& operator=( + const intrusive_ptr_target& /*other*/) noexcept { + return *this; + } + + private: + /** + * This is called when refcount reaches zero. + * You can override this to release expensive resources. + * There might still be weak references, so your object might not get + * destructed yet, but you can assume the object isn't used anymore, + * i.e. no more calls to methods or accesses to members (we just can't + * destruct it yet because we need the weakcount accessible). + * + * If there are no weak references (i.e. your class is about to be + * destructed), this function WILL NOT be called. + */ + virtual void release_resources() {} + + /** + * These two methods are called when the refcount transitions between one + * and two and the object has a PyObject wrapper. + */ + virtual void incref_pyobject() const noexcept {} + virtual void decref_pyobject() const noexcept {} + virtual bool try_incref_pyobject() const noexcept { + return false; + } + + uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const { + return detail::refcount(combined_refcount_.load(order)); + } + + uint32_t weakcount( + std::memory_order order = std::memory_order_relaxed) const { + return detail::weakcount(combined_refcount_.load(order)); + } +}; + +namespace detail { + +#ifndef C10_MOBILE +template <> +struct TargetTraits { + // A generic intrusive_ptr may actually be a TensorImpl + // or StorageImpl, so we have to allow for PyObject support. + static constexpr bool can_have_pyobject = true; +}; +#endif + +} // namespace detail + +template +class weak_intrusive_ptr; + +template < + class TTarget, + class NullType = detail::intrusive_target_default_null_type> +class intrusive_ptr final { + private: +// the following static assert would be nice to have but it requires +// the target class T to be fully defined when intrusive_ptr is instantiated +// this is a problem for classes that contain pointers to themselves +// static_assert( +// std::is_base_of_v, +// "intrusive_ptr can only be used for classes that inherit from +// intrusive_ptr_target."); +#ifndef _WIN32 + // This static_assert triggers on MSVC + // error C2131: expression did not evaluate to a constant + static_assert( + // NOLINTNEXTLINE(misc-redundant-expression) + NullType::singleton() == NullType::singleton(), + "NullType must have a constexpr singleton() method"); +#endif + static_assert( + std::is_base_of_v< + TTarget, + std::remove_pointer_t>, + "NullType::singleton() must return a element_type* pointer"); + + TTarget* target_; + + template + friend struct ExclusivelyOwnedTensorTraits; + template + friend class intrusive_ptr; + friend class weak_intrusive_ptr; + + // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom + // smart holder in pybind11 could access the private constructor of + // intrusive_ptr(T*) which took the ownership of the object. This is required + // by customer holder macro PYBIND11_DECLARE_HOLDER_TYPE, where it uses + // intrusive_ptr(TTarget*) to initialize and take ownership of the object. For + // details, see + // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers + template + friend class pybind11::class_; + + void retain_() noexcept { + if (target_ != NullType::singleton()) { + uint64_t combined = detail::atomic_combined_refcount_increment( + target_->combined_refcount_, detail::kReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + new_refcount != 1, + "intrusive_ptr: Cannot increase refcount after it reached zero."); + + if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 1 to 2, we need to incref the + // PyObject. In other words, we need to ensure that the PyObject stays + // alive now that we have a C++ reference to this object in addition to + // the PyObject itself. + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { + target_->incref_pyobject(); + } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !detail::has_pyobject(combined), + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + } + } + } + + void reset_() noexcept { + if (target_ != NullType::singleton()) { + reset_not_null_(target_); + } + } + + // C10_NOINLINE to keep binary size a bit smaller. We pass TTarget* here + // to avoid an extra pointer dereference in the call from reset_(). + C10_NOINLINE static void reset_not_null_(TTarget* target) noexcept { + if (detail::is_uniquely_owned( + target->combined_refcount_.load(std::memory_order_acquire))) { + // Both counts are 1, so there are no weak references and + // we are releasing the last strong reference. No other + // threads can observe the effects of this target deletion + // call (e.g. calling use_count()) without a data race. + target->combined_refcount_.store(0, std::memory_order_relaxed); + delete target; + return; + } + + auto combined_refcount = detail::atomic_combined_refcount_decrement( + target->combined_refcount_, detail::kReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined_refcount); + bool has_pyobject = detail::has_pyobject(combined_refcount); + if (new_refcount == 0) { + if (detail::weakcount(combined_refcount) == 1) { + delete target; + return; + } + // See comment above about weakcount. As long as refcount>0, + // weakcount is one larger than the actual number of weak references. + // So we need to decrement it here. + release_resources_and_decrement_weakrefs_(target); + } else if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 2 to 1, we need to decref the + // PyObject. In other words, we don't want to keep the PyObject alive if + // there are no C++ references to this object other than the PyObject + // itself. + if (has_pyobject && new_refcount == 1) { + target->decref_pyobject(); + } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !has_pyobject, + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + } + } + + C10_NOINLINE static void release_resources_and_decrement_weakrefs_( + TTarget* target) noexcept { + // justification for const_cast: release_resources is basically a + // destructor and a destructor always mutates the object, even for + // const objects. + const_cast*>(target)->release_resources(); + if (detail::atomic_weakcount_decrement(target->combined_refcount_) == 0) { + delete target; + } + } + + // raw pointer constructors are not public because we shouldn't make + // intrusive_ptr out of raw pointers except from inside the make_intrusive(), + // reclaim() and weak_intrusive_ptr::lock() implementations. + + // This constructor will increase the ref counter for you. + // This constructor will be used by the make_intrusive(), and also pybind11, + // which wrap the intrusive_ptr holder around the raw pointer and incref + // correspondingly (pybind11 requires raw pointer constructor to incref by + // default). + explicit intrusive_ptr(TTarget* target) + : intrusive_ptr(target, raw::DontIncreaseRefcount{}) { + if (target_ != NullType::singleton()) { + // We just created result.target_, so we know no other thread has + // access to it, so we know we needn't care about memory ordering. + // (On x86_64, a store with memory_order_relaxed generates a plain old + // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is + // much more expensive: https://godbolt.org/z/eKPzj8.) + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + target_->combined_refcount_.load(std::memory_order_relaxed) == 0, + "intrusive_ptr: Newly-created target had non-zero refcounts. Does its " + "constructor do something strange like incref or create an " + "intrusive_ptr from `this`?"); + target_->combined_refcount_.store( + detail::kUniqueRef, std::memory_order_relaxed); + } + } + + public: + using element_type = TTarget; + + intrusive_ptr() noexcept + : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} + + /* implicit */ intrusive_ptr(std::nullptr_t) noexcept + : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} + + // This constructor will not increase the ref counter for you. + // We use the tagged dispatch mechanism to explicitly mark this constructor + // to not increase the refcount + explicit intrusive_ptr( + TTarget* target, + raw::DontIncreaseRefcount /*unused*/) noexcept + : target_(target) {} + + explicit intrusive_ptr(std::unique_ptr rhs) noexcept + : intrusive_ptr(rhs.release()) {} + + intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) { + rhs.target_ = NullType::singleton(); + } + + template + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + /* implicit */ intrusive_ptr(intrusive_ptr&& rhs) noexcept + : target_( + detail::assign_ptr_(rhs.target_)) { + static_assert( + std::is_convertible_v, + "Type mismatch. intrusive_ptr move constructor got pointer of wrong type."); + rhs.target_ = FromNullType::singleton(); + } + + intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) { + retain_(); + } + + template + /* implicit */ intrusive_ptr(const intrusive_ptr& rhs) + : target_( + detail::assign_ptr_(rhs.target_)) { + static_assert( + std::is_convertible_v, + "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type."); + retain_(); + } + + ~intrusive_ptr() noexcept { + reset_(); + } + + intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept { + // NOLINTNEXTLINE(*assign*) + return this->template operator= (std::move(rhs)); + } + + template + intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept { + static_assert( + std::is_convertible_v, + "Type mismatch. intrusive_ptr move assignment got pointer of wrong type."); + intrusive_ptr tmp = std::move(rhs); + swap(tmp); + return *this; + } + + // Assignment is implemented using copy and swap. That's safe for self + // assignment. + // NOLINTNEXTLINE(bugprone-unhandled-self-assignment) + intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept { + // NOLINTNEXTLINE(*assign-operator, *assignment-signature) + return this->template operator= (rhs); + } + + template + intrusive_ptr& operator=( + const intrusive_ptr& rhs) & noexcept { + static_assert( + std::is_convertible_v, + "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type."); + intrusive_ptr tmp = rhs; + swap(tmp); + return *this; + } + + TTarget* get() const noexcept { + return target_; + } + + TTarget& operator*() const noexcept { + return *target_; + } + + TTarget* operator->() const noexcept { + return target_; + } + + operator bool() const noexcept { + return target_ != NullType::singleton(); + } + + void reset() noexcept { + reset_(); + target_ = NullType::singleton(); + } + + void swap(intrusive_ptr& rhs) noexcept { + std::swap(target_, rhs.target_); + } + + // We do a lot of null-pointer checks in our code, good to have this be cheap. + bool defined() const noexcept { + return target_ != NullType::singleton(); + } + + uint32_t use_count() const noexcept { + if (target_ == NullType::singleton()) { + return 0; + } + return target_->refcount(std::memory_order_relaxed); + } + + uint32_t weak_use_count() const noexcept { + if (target_ == NullType::singleton()) { + return 0; + } + return target_->weakcount(std::memory_order_relaxed); + } + + bool unique() const noexcept { + return use_count() == 1; + } + + /** + * Stronger than unique() in that it must not have any weakrefs as well. + */ + bool is_uniquely_owned() const noexcept { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton()); + return detail::is_uniquely_owned( + target_->combined_refcount_.load(std::memory_order_acquire)); + } + + /** + * Returns an owning (!) pointer to the underlying object and makes the + * intrusive_ptr instance invalid. That means the refcount is not decreased. + * You *must* put the returned pointer back into a intrusive_ptr using + * intrusive_ptr::reclaim(ptr) to properly destruct it. + * This is helpful for C APIs. + */ + TTarget* release() noexcept { + // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) + TTarget* result = target_; + target_ = NullType::singleton(); + return result; + } + + /** + * Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes + * over ownership. That means the refcount is not increased. + * This is the counter-part to intrusive_ptr::release() and the pointer + * passed in *must* have been created using intrusive_ptr::release(). + */ + static intrusive_ptr reclaim(TTarget* owning_ptr) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + owning_ptr == NullType::singleton() || owning_ptr->refcount() == 0 || + owning_ptr->weakcount(), + "TTarget violates the invariant that refcount > 0 => weakcount > 0"); + return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{}); + } + + /** + * Takes an owning pointer to TTarget* and creates an intrusive_ptr + * representing a new reference, i.e. the raw pointer retains + * ownership. + */ + static intrusive_ptr reclaim_copy(TTarget* owning_ptr) { + auto ret = reclaim(owning_ptr); + ret.retain_(); + return ret; + } + + /** + * Allocate a heap object with args and wrap it inside a intrusive_ptr and + * incref. This is a helper function to let make_intrusive() access private + * intrusive_ptr constructors. + */ + template + static intrusive_ptr make(Args&&... args) { + return intrusive_ptr(new TTarget(std::forward(args)...)); + } + + /** + * Turn a new instance of TTarget (e.g., literally allocated + * using new TTarget(...) into an intrusive_ptr. If possible, + * use intrusive_ptr::make instead which statically guarantees + * that the allocation was done properly. + * + * At the moment, the only reason this method exists is because + * pybind11 holder types expect to be able to allocate in + * this way (because pybind11 handles the new allocation itself). + */ + static intrusive_ptr unsafe_steal_from_new(TTarget* raw_ptr) { + return intrusive_ptr(raw_ptr); + } + + /** + * Turn an instance of TTarget that should not be reference counted + * (e.g., allocated into an arena with placement new) into an + * intrusive_ptr. This is gratuitously unsafe and should only be + * used if you can guarantee that the pointer will not escape and be + * refcounted as normal. + * + * `expected_decrefs` is a debugging parameter: it indicates the + * number of strong owners the intrusive_ptr_target in question is + * expected to get. In most use cases, this will likely be 1. + * + * The reason this method exists is for manually sharing + * StorageImpls across Tensors in the static runtime. It needs + * access to private intrusive_ptr members so that the refcounts can + * be initialized to custom values. + */ + static intrusive_ptr unsafe_adapt_non_heap_allocated( + TTarget* raw_ptr, + uint32_t expected_decrefs) { + intrusive_ptr result(raw_ptr, raw::DontIncreaseRefcount{}); + // kImpracticallyHugeReferenceCount is impractically huge for a reference + // count, while being in no danger of overflowing uint32_t. We actually only + // need to initialize the refcount to 2 -- we are just doing an unbalanced + // incref to prevent the non-heap-allocated target from being + // freed, and we are optimizing that incref by directly + // initializing the refcounts rather than doing an expensive + // atomic increment. The reason to use kImpracticallyHugeReferenceCount is + // to accommodate the debug assertions in ~intrusive_ptr_target. +#ifdef NDEBUG + expected_decrefs = 0; +#endif + result.target_->combined_refcount_.store( + detail::refcount( + detail::kImpracticallyHugeReferenceCount + expected_decrefs) | + detail::kImpracticallyHugeWeakReferenceCount, + std::memory_order_relaxed); + return result; + } + + /** + * Turn a **non-owning raw pointer** to an intrusive_ptr. It is + * the moral equivalent of enable_shared_from_this on a shared pointer. + * + * This method is only valid for objects that are already live. If + * you are looking for the moral equivalent of unique_ptr(T*) + * constructor, see steal_from_new. + * + * TODO: https://github.com/pytorch/pytorch/issues/56482 + */ + static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) { + // See Note [Stack allocated intrusive_ptr_target safety] + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + raw_ptr == NullType::singleton() || raw_ptr->refcount() > 0, + "intrusive_ptr: Can only reclaim pointers that are owned by someone"); + auto ptr = reclaim(raw_ptr); // doesn't increase refcount + ptr.retain_(); + return ptr; + } +}; + +template < + class TTarget, + class NullType = detail::intrusive_target_default_null_type, + class... Args> +inline intrusive_ptr make_intrusive(Args&&... args) { + return intrusive_ptr::make(std::forward(args)...); +} + +template +inline void swap( + intrusive_ptr& lhs, + intrusive_ptr& rhs) noexcept { + lhs.swap(rhs); +} + +// To allow intrusive_ptr inside std::map or std::set, we need operator< +template +inline bool operator<( + const intrusive_ptr& lhs, + const intrusive_ptr& rhs) noexcept { + return lhs.get() < rhs.get(); +} + +template +inline bool operator==( + const intrusive_ptr& lhs, + const intrusive_ptr& rhs) noexcept { + return lhs.get() == rhs.get(); +} + +template +inline bool operator==( + const intrusive_ptr& lhs, + std::nullptr_t) noexcept { + return lhs.get() == nullptr; +} + +template +inline bool operator==( + std::nullptr_t, + const intrusive_ptr& rhs) noexcept { + return nullptr == rhs.get(); +} + +template +inline bool operator!=( + const intrusive_ptr& lhs, + const intrusive_ptr& rhs) noexcept { + return !operator==(lhs, rhs); +} + +template +inline bool operator!=( + const intrusive_ptr& lhs, + std::nullptr_t) noexcept { + return !operator==(lhs, nullptr); +} + +template +inline bool operator!=( + std::nullptr_t, + const intrusive_ptr& rhs) noexcept { + return !operator==(nullptr, rhs); +} +template +struct MaybeOwnedTraits> { + using owned_type = c10::intrusive_ptr; + using borrow_type = c10::intrusive_ptr; + + static borrow_type createBorrow(const owned_type& from) { + return borrow_type::reclaim(from.get()); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.release(); + lhs = borrow_type::reclaim(rhs.get()); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.release(); + } + + static const owned_type& referenceFromBorrow( + const borrow_type& borrow) noexcept { + return borrow; + } + + static const owned_type* pointerFromBorrow( + const borrow_type& borrow) noexcept { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) noexcept { + return true; + } +}; + +template < + typename TTarget, + class NullType = detail::intrusive_target_default_null_type> +class weak_intrusive_ptr final { + private: + static_assert( + std::is_base_of_v, + "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target."); +#ifndef _WIN32 + // This static_assert triggers on MSVC + // error C2131: expression did not evaluate to a constant + static_assert( + NullType::singleton() == NullType::singleton(), + "NullType must have a constexpr singleton() method"); +#endif + static_assert( + std::is_base_of_v< + TTarget, + std::remove_pointer_t>, + "NullType::singleton() must return a element_type* pointer"); + + TTarget* target_; + + template + friend class weak_intrusive_ptr; + + void retain_() { + if (target_ != NullType::singleton()) { + uint32_t new_weakcount = + detail::atomic_weakcount_increment(target_->combined_refcount_); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + new_weakcount != 1, + "weak_intrusive_ptr: Cannot increase weakcount after it reached zero."); + } + } + + void reset_() noexcept { + if (target_ != NullType::singleton() && + detail::atomic_weakcount_decrement(target_->combined_refcount_) == 0) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete) + delete target_; + } + target_ = NullType::singleton(); + } + + constexpr explicit weak_intrusive_ptr(TTarget* target) : target_(target) {} + + public: + using element_type = TTarget; + + explicit weak_intrusive_ptr(const intrusive_ptr& ptr) + : weak_intrusive_ptr(ptr.get()) { + retain_(); + } + + weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) { + rhs.target_ = NullType::singleton(); + } + + template + /* implicit */ weak_intrusive_ptr( + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + weak_intrusive_ptr&& rhs) noexcept + : target_( + detail::assign_ptr_(rhs.target_)) { + static_assert( + std::is_convertible_v, + "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type."); + rhs.target_ = FromNullType::singleton(); + } + + weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) { + retain_(); + } + + template + /* implicit */ weak_intrusive_ptr( + const weak_intrusive_ptr& rhs) + : target_( + detail::assign_ptr_(rhs.target_)) { + static_assert( + std::is_convertible_v, + "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type."); + retain_(); + } + + ~weak_intrusive_ptr() noexcept { + reset_(); + } + + weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept { + // NOLINTNEXTLINE(*assign*) + return this->template operator= (std::move(rhs)); + } + + template + weak_intrusive_ptr& operator=( + weak_intrusive_ptr&& rhs) & noexcept { + static_assert( + std::is_convertible_v, + "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type."); + weak_intrusive_ptr tmp = std::move(rhs); + swap(tmp); + return *this; + } + + weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept { + if (this == &rhs) { + return *this; + } + // NOLINTNEXTLINE(*assign*) + return this->template operator= (rhs); + } + + weak_intrusive_ptr& operator=( + const intrusive_ptr& rhs) & noexcept { + weak_intrusive_ptr tmp(rhs); + swap(tmp); + return *this; + } + + template + weak_intrusive_ptr& operator=( + const weak_intrusive_ptr& rhs) & noexcept { + static_assert( + std::is_convertible_v, + "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type."); + weak_intrusive_ptr tmp = rhs; + swap(tmp); + return *this; + } + + void reset() noexcept { + reset_(); + } + + void swap(weak_intrusive_ptr& rhs) noexcept { + TTarget* tmp = target_; + target_ = rhs.target_; + rhs.target_ = tmp; + } + + // NB: This should ONLY be used by the std::hash implementation + // for weak_intrusive_ptr. Another way you could do this is + // friend std::hash, but this triggers two + // bugs: + // + // (1) It triggers an nvcc bug, where std::hash in a friend class + // declaration gets preprocessed into hash, which then cannot + // actually be found. The error in this case looks like: + // + // error: no template named 'hash'; did you mean 'std::hash'? + // + // (2) On OS X, std::hash is declared as a struct, not a class. + // This twings: + // + // error: class 'hash' was previously declared as a struct + // [-Werror,-Wmismatched-tags] + // + // Both of these are work-aroundable, but on the whole, I decided + // it would be simpler and easier to make work if we just expose + // an unsafe getter for target_ + // + TTarget* _unsafe_get_target() const noexcept { + return target_; + } + + uint32_t use_count() const noexcept { + if (target_ == NullType::singleton()) { + return 0; + } + return target_->refcount( + std::memory_order_relaxed); // refcount, not weakcount! + } + + uint32_t weak_use_count() const noexcept { + if (target_ == NullType::singleton()) { + return 0; + } + return target_->weakcount(std::memory_order_relaxed); + } + + bool expired() const noexcept { + return use_count() == 0; + } + + intrusive_ptr lock() const noexcept { + if (target_ == NullType::singleton()) { + return intrusive_ptr(); + } else { + bool increfed = false; + auto combined_refcount = + target_->combined_refcount_.load(std::memory_order_relaxed); + do { + if (detail::refcount(combined_refcount) == 0) { + // Object already destructed, no strong references left anymore. + // Return nullptr. + return intrusive_ptr(); + } + if constexpr (detail::TargetTraits::can_have_pyobject) { + if (detail::has_pyobject(combined_refcount) && + detail::refcount(combined_refcount) == 1 && !increfed) { + // Object has a python wrapper with no other C++ references. + // We need to to incref the Python object before we acquire a + // strong reference to the C++ object to avoid a situation + // where the Python object is deallocated concurrently. + if (!target_->try_incref_pyobject()) { + return intrusive_ptr(); + } + increfed = true; + } + } + } while (!target_->combined_refcount_.compare_exchange_weak( + combined_refcount, + combined_refcount + detail::kReferenceCountOne, + std::memory_order_acquire, + std::memory_order_relaxed)); + + if constexpr (detail::TargetTraits::can_have_pyobject) { + if (increfed && detail::refcount(combined_refcount) != 1) { + target_->decref_pyobject(); + } + } + + return intrusive_ptr( + target_, raw::DontIncreaseRefcount{}); + } + } + + /** + * Returns an owning (but still only weakly referenced) pointer to the + * underlying object and makes the weak_intrusive_ptr instance invalid. + * That means the weakcount is not decreased. + * You *must* put the returned pointer back into a weak_intrusive_ptr using + * weak_intrusive_ptr::reclaim(ptr) to properly destruct it. + * This is helpful for C APIs. + */ + TTarget* release() noexcept { + TTarget* result = target_; + target_ = NullType::singleton(); + return result; + } + + /** + * Takes an owning (but must be weakly referenced) pointer to TTarget* and + * creates a weak_intrusive_ptr that takes over ownership. + * This means that the weakcount is not increased. + * This is the counter-part to weak_intrusive_ptr::release() and the pointer + * passed in *must* have been created using weak_intrusive_ptr::release(). + */ + static weak_intrusive_ptr reclaim(TTarget* owning_weak_ptr) { + // See Note [Stack allocated intrusive_ptr_target safety] + // if refcount > 0, weakcount must be >1 for weak references to exist. + // see weak counting explanation at top of this file. + // if refcount == 0, weakcount only must be >0. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + owning_weak_ptr == NullType::singleton() || + owning_weak_ptr->weakcount() > 1 || + (owning_weak_ptr->refcount() == 0 && + owning_weak_ptr->weakcount() > 0), + "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release()."); + return weak_intrusive_ptr(owning_weak_ptr); + } + + /** + * Takes a pointer to TTarget* (may be weak or strong) and creates a + * new weak_intrusive_ptr representing a new weak reference, i.e. + * the raw pointer retains ownership. + */ + static weak_intrusive_ptr reclaim_copy(TTarget* owning_ptr) { + auto ret = reclaim(owning_ptr); + ret.retain_(); + return ret; + } + + template + friend bool operator<( + const weak_intrusive_ptr& lhs, + const weak_intrusive_ptr& rhs) noexcept; + template + friend bool operator==( + const weak_intrusive_ptr& lhs, + const weak_intrusive_ptr& rhs) noexcept; +}; + +template +inline void swap( + weak_intrusive_ptr& lhs, + weak_intrusive_ptr& rhs) noexcept { + lhs.swap(rhs); +} + +// To allow weak_intrusive_ptr inside std::map or std::set, we need operator< +template +inline bool operator<( + const weak_intrusive_ptr& lhs, + const weak_intrusive_ptr& rhs) noexcept { + return lhs.target_ < rhs.target_; +} + +template +inline bool operator==( + const weak_intrusive_ptr& lhs, + const weak_intrusive_ptr& rhs) noexcept { + return lhs.target_ == rhs.target_; +} + +template +inline bool operator!=( + const weak_intrusive_ptr& lhs, + const weak_intrusive_ptr& rhs) noexcept { + return !operator==(lhs, rhs); +} + +// Alias for documentary purposes, to more easily distinguish +// weak raw intrusive pointers from intrusive pointers. +using weak_intrusive_ptr_target = intrusive_ptr_target; + +// This namespace provides some methods for working with +// raw pointers that subclass intrusive_ptr_target. They are not provided +// as methods on intrusive_ptr_target, because ideally you would not need these +// methods at all (use smart pointers), but if you are dealing with legacy code +// that still needs to pass around raw pointers, you may find these quite +// useful. +// +// An important usage note: some functions are only valid if you have a +// strong raw pointer to the object, while others are only valid if you +// have a weak raw pointer to the object. ONLY call intrusive_ptr namespace +// functions on strong pointers, and weak_intrusive_ptr namespace functions +// on weak pointers. If you mix it up, you may get an assert failure. +namespace raw { + +namespace intrusive_ptr { + +// WARNING: Unlike the reclaim() API, it is NOT valid to pass +// NullType::singleton to this function +inline void incref(intrusive_ptr_target* self) { + if (self) { + uint64_t combined = detail::atomic_combined_refcount_increment( + self->combined_refcount_, detail::kReferenceCountOne); + +#ifndef C10_MOBILE + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { + self->incref_pyobject(); + } +#else + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined)); +#endif + } +} + +// WARNING: Unlike the reclaim() API, it is NOT valid to pass +// NullType::singleton to this function +inline void decref(intrusive_ptr_target* self) { + // Let it die + c10::intrusive_ptr::reclaim(self); + // NB: Caller still has 'self' pointer, but it's now invalid. + // If you want more safety, used the actual c10::intrusive_ptr class +} + +template +inline T* make_weak(T* self) { + // NB: 'this' is a strong pointer, but we return a weak pointer + auto ptr = c10::intrusive_ptr::reclaim(self); + c10::weak_intrusive_ptr wptr(ptr); + ptr.release(); + return wptr.release(); +} + +inline uint32_t use_count(intrusive_ptr_target* self) { + auto ptr = c10::intrusive_ptr::reclaim(self); + auto r = ptr.use_count(); + ptr.release(); + return r; +} + +} // namespace intrusive_ptr + +namespace weak_intrusive_ptr { + +inline void incref(weak_intrusive_ptr_target* self) { + detail::atomic_weakcount_increment(self->combined_refcount_); +} + +inline void decref(weak_intrusive_ptr_target* self) { + // Let it die + c10::weak_intrusive_ptr::reclaim(self); + // NB: You still "have" the 'self' pointer, but it's now invalid. + // If you want more safety, used the actual c10::weak_intrusive_ptr class +} + +template +inline T* lock(T* self) { + auto wptr = c10::weak_intrusive_ptr::reclaim(self); + auto ptr = wptr.lock(); + wptr.release(); + return ptr.release(); +} + +// This gives the STRONG refcount of a WEAK pointer +inline uint32_t use_count(weak_intrusive_ptr_target* self) { + auto wptr = c10::weak_intrusive_ptr::reclaim(self); + auto r = wptr.use_count(); + wptr.release(); + return r; +} + +} // namespace weak_intrusive_ptr + +} // namespace raw + +} // namespace c10 + +namespace std { +// To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or +// std::unordered_set, we need std::hash +template +struct hash> { + size_t operator()(const c10::intrusive_ptr& x) const { + return std::hash()(x.get()); + } +}; +template +struct hash> { + size_t operator()(const c10::weak_intrusive_ptr& x) const { + return std::hash()(x._unsafe_get_target()); + } +}; +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/irange.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/irange.h new file mode 100644 index 0000000000000000000000000000000000000000..bc2a018db397a56dee0199af77509fc23dfe405b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/irange.h @@ -0,0 +1,128 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include +#include + +namespace c10 { + +namespace detail { + +template < + typename I, + bool one_sided = false, + std::enable_if_t, int> = 0> +struct integer_iterator { + using iterator_category = std::input_iterator_tag; + using value_type = I; + using difference_type = std::ptrdiff_t; + using pointer = I*; + using reference = I&; + + explicit constexpr integer_iterator(I val) : value(val) {} + + constexpr I operator*() const { + return value; + } + + constexpr I const* operator->() const { + return &value; + } + + constexpr integer_iterator& operator++() { + ++value; + return *this; + } + + constexpr integer_iterator operator++(int) { + const auto copy = *this; + ++*this; + return copy; + } + + constexpr bool operator==(const integer_iterator& other) const { + if constexpr (one_sided) { + // Range-for loops' end test is `begin != end`, not `begin < + // end`. To handle `c10::irange(n)` where n < 0 (which should be + // empty), we just make `begin != end` fail whenever `end` is + // negative. + return is_negative(other.value) || value == other.value; + } else { + return value == other.value; + } + // Suppress "warning: missing return statement at end of non-void function" + // which Nvidia's Robert Crovella confirms is an NVCC compiler error + // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27 + // `__builtin_unreachable();` would be best here, but it's not + // available with all compilers. So we instead return an arbitrary + // value trusting that this line will, in fact, never be reached. + return false; // Horrible hack + } + + constexpr bool operator!=(const integer_iterator& other) const { + return !(*this == other); + } + + protected: + I value; +}; + +} // namespace detail + +template < + typename I, + bool one_sided = false, + std::enable_if_t, bool> = true> +struct integer_range { + public: + constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {} + using iterator = detail::integer_iterator; + constexpr iterator begin() const { + return begin_; + } + constexpr iterator end() const { + return end_; + } + + private: + iterator begin_; + iterator end_; +}; + +/// Creates an integer range for the half-open interval [begin, end) +/// If end<=begin, then the range is empty. +/// The range has the type of the `end` integer; `begin` integer is +/// cast to this type. +template < + typename Integer1, + typename Integer2, + std::enable_if_t, bool> = true, + std::enable_if_t, bool> = true> +constexpr integer_range irange(Integer1 begin, Integer2 end) { + // If end<=begin then the range is empty; we can achieve this effect by + // choosing the larger of {begin, end} as the loop terminator + return { + static_cast(begin), + std::max(static_cast(begin), end)}; +} + +/// Creates an integer range for the half-open interval [0, end) +/// If end<=begin, then the range is empty +template < + typename Integer, + std::enable_if_t, bool> = true> +constexpr integer_range irange(Integer end) { + return {Integer(), end}; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/llvmMathExtras.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/llvmMathExtras.h new file mode 100644 index 0000000000000000000000000000000000000000..6884e20d112ace8886c69b10499f830c58c3703f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/llvmMathExtras.h @@ -0,0 +1,910 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===-- llvm/Support/MathExtras.h - Useful math functions -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some functions that are useful for math stuff. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __ANDROID_NDK__ +#include +#endif + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + +#ifndef LLVM_GNUC_PREREQ +#if defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) +#define LLVM_GNUC_PREREQ(maj, min, patch) \ + ((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) + __GNUC_PATCHLEVEL__ >= \ + ((maj) << 20) + ((min) << 10) + (patch)) +#elif defined(__GNUC__) && defined(__GNUC_MINOR__) +#define LLVM_GNUC_PREREQ(maj, min, patch) \ + ((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) >= ((maj) << 20) + ((min) << 10)) +#else +#define LLVM_GNUC_PREREQ(maj, min, patch) 0 +#endif +#endif + +#ifdef _MSC_VER +// Declare these intrinsics manually rather including intrin.h. It's very +// expensive, and MathExtras.h is popular. +// #include +extern "C" { +unsigned char _BitScanForward(unsigned long* _Index, unsigned long _Mask); +unsigned char _BitScanForward64(unsigned long* _Index, unsigned __int64 _Mask); +unsigned char _BitScanReverse(unsigned long* _Index, unsigned long _Mask); +unsigned char _BitScanReverse64(unsigned long* _Index, unsigned __int64 _Mask); +} +#endif + +namespace c10::llvm { +/// The behavior an operation has on an input of 0. +enum ZeroBehavior { + /// The returned value is undefined. + ZB_Undefined, + /// The returned value is numeric_limits::max() + ZB_Max, + /// The returned value is numeric_limits::digits + ZB_Width +}; + +namespace detail { +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior /*unused*/) { + if (!Val) + return std::numeric_limits::digits; + if (Val & 0x1) + return 0; + + // Bisection method. + std::size_t ZeroBits = 0; + T Shift = std::numeric_limits::digits >> 1; + T Mask = std::numeric_limits::max() >> Shift; + while (Shift) { + if ((Val & Mask) == 0) { + Val >>= Shift; + ZeroBits |= Shift; + } + Shift >>= 1; + Mask >>= Shift; + } + return ZeroBits; + } +}; + +#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER) +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 32; + +#if __has_builtin(__builtin_ctz) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_ctz(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanForward(&Index, Val); + return Index; +#endif + } +}; + +#if !defined(_MSC_VER) || defined(_M_X64) +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 64; + +#if __has_builtin(__builtin_ctzll) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_ctzll(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanForward64(&Index, Val); + return Index; +#endif + } +}; +#endif +#endif +} // namespace detail + +/// Count number of 0's from the least significant bit to the most +/// stopping at the first 1. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are +/// valid arguments. +template +std::size_t countTrailingZeros(T Val, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return llvm::detail::TrailingZerosCounter::count(Val, ZB); +} + +namespace detail { +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior /*unused*/) { + if (!Val) + return std::numeric_limits::digits; + + // Bisection method. + std::size_t ZeroBits = 0; + for (T Shift = std::numeric_limits::digits >> 1; Shift; Shift >>= 1) { + T Tmp = Val >> Shift; + if (Tmp) + Val = Tmp; + else + ZeroBits |= Shift; + } + return ZeroBits; + } +}; + +#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER) +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 32; + +#if __has_builtin(__builtin_clz) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_clz(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanReverse(&Index, Val); + return Index ^ 31; +#endif + } +}; + +#if !defined(_MSC_VER) || defined(_M_X64) +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 64; + +#if __has_builtin(__builtin_clzll) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_clzll(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanReverse64(&Index, Val); + return Index ^ 63; +#endif + } +}; +#endif +#endif +} // namespace detail + +/// Count number of 0's from the most significant bit to the least +/// stopping at the first 1. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are +/// valid arguments. +template +std::size_t countLeadingZeros(T Val, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return llvm::detail::LeadingZerosCounter::count(Val, ZB); +} + +/// Get the index of the first set bit starting from the least +/// significant bit. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are +/// valid arguments. +template +T findFirstSet(T Val, ZeroBehavior ZB = ZB_Max) { + if (ZB == ZB_Max && Val == 0) + return std::numeric_limits::max(); + + return countTrailingZeros(Val, ZB_Undefined); +} + +/// Create a bitmask with the N right-most bits set to 1, and all other +/// bits set to 0. Only unsigned types are allowed. +template +T maskTrailingOnes(unsigned N) { + static_assert(std::is_unsigned_v, "Invalid type!"); + const unsigned Bits = CHAR_BIT * sizeof(T); + assert(N <= Bits && "Invalid bit index"); + return N == 0 ? 0 : (T(-1) >> (Bits - N)); +} + +/// Create a bitmask with the N left-most bits set to 1, and all other +/// bits set to 0. Only unsigned types are allowed. +template +T maskLeadingOnes(unsigned N) { + return ~maskTrailingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Create a bitmask with the N right-most bits set to 0, and all other +/// bits set to 1. Only unsigned types are allowed. +template +T maskTrailingZeros(unsigned N) { + return maskLeadingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Create a bitmask with the N left-most bits set to 0, and all other +/// bits set to 1. Only unsigned types are allowed. +template +T maskLeadingZeros(unsigned N) { + return maskTrailingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Get the index of the last set bit starting from the least +/// significant bit. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are +/// valid arguments. +template +T findLastSet(T Val, ZeroBehavior ZB = ZB_Max) { + if (ZB == ZB_Max && Val == 0) + return std::numeric_limits::max(); + + // Use ^ instead of - because both gcc and llvm can remove the associated ^ + // in the __builtin_clz intrinsic on x86. + return countLeadingZeros(Val, ZB_Undefined) ^ + (std::numeric_limits::digits - 1); +} + +/// Macro compressed bit reversal table for 256 bits. +/// +/// http://graphics.stanford.edu/~seander/bithacks.html#BitReverseTable +/// NOLINTNEXTLINE(*c-arrays*) +static constexpr unsigned char BitReverseTable256[256] = { +#define R2(n) n, n + 2 * 64, n + 1 * 64, n + 3 * 64 +#define R4(n) R2(n), R2(n + 2 * 16), R2(n + 1 * 16), R2(n + 3 * 16) +#define R6(n) R4(n), R4(n + 2 * 4), R4(n + 1 * 4), R4(n + 3 * 4) + R6(0), + R6(2), + R6(1), + R6(3) +#undef R2 +#undef R4 +#undef R6 +}; + +/// Reverse the bits in \p Val. +template +T reverseBits(T Val) { + // NOLINTNEXTLINE(*c-arrays*) + unsigned char in[sizeof(Val)]; + // NOLINTNEXTLINE(*c-arrays*) + unsigned char out[sizeof(Val)]; + std::memcpy(in, &Val, sizeof(Val)); + for (unsigned i = 0; i < sizeof(Val); ++i) + out[(sizeof(Val) - i) - 1] = BitReverseTable256[in[i]]; + std::memcpy(&Val, out, sizeof(Val)); + return Val; +} + +// NOTE: The following support functions use the _32/_64 extensions instead of +// type overloading so that signed and unsigned integers can be used without +// ambiguity. + +/// Return the high 32 bits of a 64 bit value. +constexpr inline uint32_t Hi_32(uint64_t Value) { + return static_cast(Value >> 32); +} + +/// Return the low 32 bits of a 64 bit value. +constexpr inline uint32_t Lo_32(uint64_t Value) { + return static_cast(Value); +} + +/// Make a 64-bit integer from a high / low pair of 32-bit integers. +constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) { + return ((uint64_t)High << 32) | (uint64_t)Low; +} + +/// Checks if an integer fits into the given bit width. +template +constexpr inline bool isInt(int64_t x) { + return N >= 64 || + (-(INT64_C(1) << (N - 1)) <= x && x < (INT64_C(1) << (N - 1))); +} +// Template specializations to get better code for common cases. +template <> +constexpr inline bool isInt<8>(int64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isInt<16>(int64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isInt<32>(int64_t x) { + return static_cast(x) == x; +} + +/// Checks if a signed integer is an N bit number shifted left by S. +template +constexpr inline bool isShiftedInt(int64_t x) { + static_assert( + N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number."); + static_assert(N + S <= 64, "isShiftedInt with N + S > 64 is too wide."); + return isInt(x) && (x % (UINT64_C(1) << S) == 0); +} + +/// Checks if an unsigned integer fits into the given bit width. +/// +/// This is written as two functions rather than as simply +/// +/// return N >= 64 || X < (UINT64_C(1) << N); +/// +/// to keep MSVC from (incorrectly) warning on isUInt<64> that we're shifting +/// left too many places. +template +constexpr inline std::enable_if_t<(N < 64), bool> isUInt(uint64_t X) { + static_assert(N > 0, "isUInt<0> doesn't make sense"); + return X < (UINT64_C(1) << N); +} +template +constexpr inline std::enable_if_t= 64, bool> isUInt(uint64_t /*X*/) { + return true; +} + +// Template specializations to get better code for common cases. +template <> +constexpr inline bool isUInt<8>(uint64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isUInt<16>(uint64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isUInt<32>(uint64_t x) { + return static_cast(x) == x; +} + +/// Checks if a unsigned integer is an N bit number shifted left by S. +template +constexpr inline bool isShiftedUInt(uint64_t x) { + static_assert( + N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)"); + static_assert( + N + S <= 64, "isShiftedUInt with N + S > 64 is too wide."); + // Per the two static_asserts above, S must be strictly less than 64. So + // 1 << S is not undefined behavior. + return isUInt(x) && (x % (UINT64_C(1) << S) == 0); +} + +/// Gets the maximum value for a N-bit unsigned integer. +inline uint64_t maxUIntN(uint64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + + // uint64_t(1) << 64 is undefined behavior, so we can't do + // (uint64_t(1) << N) - 1 + // without checking first that N != 64. But this works and doesn't have a + // branch. + return UINT64_MAX >> (64 - N); +} + +// Ignore the false warning "Arithmetic overflow" for MSVC +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4146) +#endif + +/// Gets the minimum value for a N-bit signed integer. +inline int64_t minIntN(int64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + // NOLINTNEXTLINE(*-narrowing-conversions) + return -(UINT64_C(1) << (N - 1)); +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/// Gets the maximum value for a N-bit signed integer. +inline int64_t maxIntN(int64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + + // This relies on two's complement wraparound when N == 64, so we convert to + // int64_t only at the very end to avoid UB. + // NOLINTNEXTLINE(*-narrowing-conversions) + return (UINT64_C(1) << (N - 1)) - 1; +} + +/// Checks if an unsigned integer fits into the given (dynamic) bit width. +inline bool isUIntN(unsigned N, uint64_t x) { + return N >= 64 || x <= maxUIntN(N); +} + +/// Checks if an signed integer fits into the given (dynamic) bit width. +inline bool isIntN(unsigned N, int64_t x) { + return N >= 64 || (minIntN(N) <= x && x <= maxIntN(N)); +} + +/// Return true if the argument is a non-empty sequence of ones starting at the +/// least significant bit with the remainder zero (32 bit version). +/// Ex. isMask_32(0x0000FFFFU) == true. +constexpr inline bool isMask_32(uint32_t Value) { + return Value && ((Value + 1) & Value) == 0; +} + +/// Return true if the argument is a non-empty sequence of ones starting at the +/// least significant bit with the remainder zero (64 bit version). +constexpr inline bool isMask_64(uint64_t Value) { + return Value && ((Value + 1) & Value) == 0; +} + +/// Return true if the argument contains a non-empty sequence of ones with the +/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true. +constexpr inline bool isShiftedMask_32(uint32_t Value) { + return Value && isMask_32((Value - 1) | Value); +} + +/// Return true if the argument contains a non-empty sequence of ones with the +/// remainder zero (64 bit version.) +constexpr inline bool isShiftedMask_64(uint64_t Value) { + return Value && isMask_64((Value - 1) | Value); +} + +/// Return true if the argument is a power of two > 0. +/// Ex. isPowerOf2_32(0x00100000U) == true (32 bit edition.) +constexpr inline bool isPowerOf2_32(uint32_t Value) { + return Value && !(Value & (Value - 1)); +} + +/// Return true if the argument is a power of two > 0 (64 bit edition.) +constexpr inline bool isPowerOf2_64(uint64_t Value) { + return Value && !(Value & (Value - 1)); +} + +/// Count the number of ones from the most significant bit to the first +/// zero bit. +/// +/// Ex. countLeadingOnes(0xFF0FFF00) == 8. +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of all ones. Only ZB_Width and +/// ZB_Undefined are valid arguments. +template +std::size_t countLeadingOnes(T Value, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return countLeadingZeros(~Value, ZB); +} + +/// Count the number of ones from the least significant bit to the first +/// zero bit. +/// +/// Ex. countTrailingOnes(0x00FF00FF) == 8. +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of all ones. Only ZB_Width and +/// ZB_Undefined are valid arguments. +template +std::size_t countTrailingOnes(T Value, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return countTrailingZeros(~Value, ZB); +} + +namespace detail { +template +struct PopulationCounter { + static unsigned count(T Value) { + // Generic version, forward to 32 bits. + static_assert(SizeOfT <= 4, "Not implemented!"); +#if defined(__GNUC__) && __GNUC__ >= 4 + return __builtin_popcount(Value); +#else + uint32_t v = Value; + v = v - ((v >> 1) & 0x55555555); + v = (v & 0x33333333) + ((v >> 2) & 0x33333333); + return ((v + (v >> 4) & 0xF0F0F0F) * 0x1010101) >> 24; +#endif + } +}; + +template +struct PopulationCounter { + static unsigned count(T Value) { +#if defined(__GNUC__) && __GNUC__ >= 4 + return __builtin_popcountll(Value); +#else + uint64_t v = Value; + v = v - ((v >> 1) & 0x5555555555555555ULL); + v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL); + v = (v + (v >> 4)) & 0x0F0F0F0F0F0F0F0FULL; + return unsigned((uint64_t)(v * 0x0101010101010101ULL) >> 56); +#endif + } +}; +} // namespace detail + +/// Count the number of set bits in a value. +/// Ex. countPopulation(0xF000F000) = 8 +/// Returns 0 if the word is zero. +template +inline unsigned countPopulation(T Value) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return detail::PopulationCounter::count(Value); +} + +/// Return the log base 2 of the specified value. +inline double Log2(double Value) { +#if defined(__ANDROID_API__) && __ANDROID_API__ < 18 + return __builtin_log(Value) / __builtin_log(2.0); +#else + return log2(Value); +#endif +} + +/// Return the floor log base 2 of the specified value, -1 if the value is zero. +/// (32 bit edition.) +/// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2 +inline unsigned Log2_32(uint32_t Value) { + return static_cast(31 - countLeadingZeros(Value)); +} + +/// Return the floor log base 2 of the specified value, -1 if the value is zero. +/// (64 bit edition.) +inline unsigned Log2_64(uint64_t Value) { + return static_cast(63 - countLeadingZeros(Value)); +} + +/// Return the ceil log base 2 of the specified value, 32 if the value is zero. +/// (32 bit edition). +/// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3 +inline unsigned Log2_32_Ceil(uint32_t Value) { + return static_cast(32 - countLeadingZeros(Value - 1)); +} + +/// Return the ceil log base 2 of the specified value, 64 if the value is zero. +/// (64 bit edition.) +inline unsigned Log2_64_Ceil(uint64_t Value) { + return static_cast(64 - countLeadingZeros(Value - 1)); +} + +/// Return the greatest common divisor of the values using Euclid's algorithm. +inline uint64_t GreatestCommonDivisor64(uint64_t A, uint64_t B) { + while (B) { + uint64_t T = B; + B = A % B; + A = T; + } + return A; +} + +/// This function takes a 64-bit integer and returns the bit equivalent double. +inline double BitsToDouble(uint64_t Bits) { + double D = 0; + static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes"); + memcpy(&D, &Bits, sizeof(Bits)); + return D; +} + +/// This function takes a 32-bit integer and returns the bit equivalent float. +inline float BitsToFloat(uint32_t Bits) { + // TODO: Use std::bit_cast once C++20 becomes available. + return c10::bit_cast(Bits); +} + +/// This function takes a double and returns the bit equivalent 64-bit integer. +/// Note that copying doubles around changes the bits of NaNs on some hosts, +/// notably x86, so this routine cannot be used if these bits are needed. +inline uint64_t DoubleToBits(double Double) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint64_t Bits; + static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes"); + memcpy(&Bits, &Double, sizeof(Double)); + return Bits; +} + +/// This function takes a float and returns the bit equivalent 32-bit integer. +/// Note that copying floats around changes the bits of NaNs on some hosts, +/// notably x86, so this routine cannot be used if these bits are needed. +inline uint32_t FloatToBits(float Float) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint32_t Bits; + static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes"); + memcpy(&Bits, &Float, sizeof(Float)); + return Bits; +} + +/// A and B are either alignments or offsets. Return the minimum alignment that +/// may be assumed after adding the two together. +constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) { + // The largest power of 2 that divides both A and B. + // + // Replace "-Value" by "1+~Value" in the following commented code to avoid + // MSVC warning C4146 + // return (A | B) & -(A | B); + return (A | B) & (1 + ~(A | B)); +} + +/// Aligns \c Addr to \c Alignment bytes, rounding up. +/// +/// Alignment should be a power of two. This method rounds up, so +/// alignAddr(7, 4) == 8 and alignAddr(8, 4) == 8. +inline uintptr_t alignAddr(const void* Addr, size_t Alignment) { + assert( + Alignment && isPowerOf2_64((uint64_t)Alignment) && + "Alignment is not a power of two!"); + + assert((uintptr_t)Addr + Alignment - 1 >= (uintptr_t)Addr); + + return (((uintptr_t)Addr + Alignment - 1) & ~(uintptr_t)(Alignment - 1)); +} + +/// Returns the necessary adjustment for aligning \c Ptr to \c Alignment +/// bytes, rounding up. +inline size_t alignmentAdjustment(const void* Ptr, size_t Alignment) { + return alignAddr(Ptr, Alignment) - (uintptr_t)Ptr; +} + +/// Returns the next power of two (in 64-bits) that is strictly greater than A. +/// Returns zero on overflow. +inline uint64_t NextPowerOf2(uint64_t A) { + A |= (A >> 1); + A |= (A >> 2); + A |= (A >> 4); + A |= (A >> 8); + A |= (A >> 16); + A |= (A >> 32); + return A + 1; +} + +/// Returns the power of two which is less than or equal to the given value. +/// Essentially, it is a floor operation across the domain of powers of two. +inline uint64_t PowerOf2Floor(uint64_t A) { + if (!A) + return 0; + return 1ull << (63 - countLeadingZeros(A, ZB_Undefined)); +} + +/// Returns the power of two which is greater than or equal to the given value. +/// Essentially, it is a ceil operation across the domain of powers of two. +inline uint64_t PowerOf2Ceil(uint64_t A) { + if (!A) + return 0; + return NextPowerOf2(A - 1); +} + +/// Returns the next integer (mod 2**64) that is greater than or equal to +/// \p Value and is a multiple of \p Align. \p Align must be non-zero. +/// +/// If non-zero \p Skew is specified, the return value will be a minimal +/// integer that is greater than or equal to \p Value and equal to +/// \p Align * N + \p Skew for some integer N. If \p Skew is larger than +/// \p Align, its value is adjusted to '\p Skew mod \p Align'. +/// +/// Examples: +/// \code +/// alignTo(5, 8) = 8 +/// alignTo(17, 8) = 24 +/// alignTo(~0LL, 8) = 0 +/// alignTo(321, 255) = 510 +/// +/// alignTo(5, 8, 7) = 7 +/// alignTo(17, 8, 1) = 17 +/// alignTo(~0LL, 8, 3) = 3 +/// alignTo(321, 255, 42) = 552 +/// \endcode +inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew = 0) { + assert(Align != 0u && "Align can't be 0."); + Skew %= Align; + return (Value + Align - 1 - Skew) / Align * Align + Skew; +} + +/// Returns the next integer (mod 2**64) that is greater than or equal to +/// \p Value and is a multiple of \c Align. \c Align must be non-zero. +template +constexpr inline uint64_t alignTo(uint64_t Value) { + static_assert(Align != 0u, "Align must be non-zero"); + return (Value + Align - 1) / Align * Align; +} + +/// Returns the integer ceil(Numerator / Denominator). +inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) { + return alignTo(Numerator, Denominator) / Denominator; +} + +/// \c alignTo for contexts where a constant expression is required. +/// \sa alignTo +/// +/// \todo FIXME: remove when \c constexpr becomes really \c constexpr +template +struct AlignTo { + static_assert(Align != 0u, "Align must be non-zero"); + template + struct from_value { + static const uint64_t value = (Value + Align - 1) / Align * Align; + }; +}; + +/// Returns the largest uint64_t less than or equal to \p Value and is +/// \p Skew mod \p Align. \p Align must be non-zero +inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) { + assert(Align != 0u && "Align can't be 0."); + Skew %= Align; + return (Value - Skew) / Align * Align + Skew; +} + +/// Returns the offset to the next integer (mod 2**64) that is greater than +/// or equal to \p Value and is a multiple of \p Align. \p Align must be +/// non-zero. +inline uint64_t OffsetToAlignment(uint64_t Value, uint64_t Align) { + return alignTo(Value, Align) - Value; +} + +/// Sign-extend the number in the bottom B bits of X to a 32-bit integer. +/// Requires 0 < B <= 32. +template +constexpr inline int32_t SignExtend32(uint32_t X) { + static_assert(B > 0, "Bit width can't be 0."); + static_assert(B <= 32, "Bit width out of range."); + return int32_t(X << (32 - B)) >> (32 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 32-bit integer. +/// Requires 0 < B < 32. +inline int32_t SignExtend32(uint32_t X, unsigned B) { + assert(B > 0 && "Bit width can't be 0."); + assert(B <= 32 && "Bit width out of range."); + return int32_t(X << (32 - B)) >> (32 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 64-bit integer. +/// Requires 0 < B < 64. +template +constexpr inline int64_t SignExtend64(uint64_t x) { + static_assert(B > 0, "Bit width can't be 0."); + static_assert(B <= 64, "Bit width out of range."); + return int64_t(x << (64 - B)) >> (64 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 64-bit integer. +/// Requires 0 < B < 64. +inline int64_t SignExtend64(uint64_t X, unsigned B) { + assert(B > 0 && "Bit width can't be 0."); + assert(B <= 64 && "Bit width out of range."); + return int64_t(X << (64 - B)) >> (64 - B); +} + +/// Subtract two unsigned integers, X and Y, of type T and return the absolute +/// value of the result. +template +std::enable_if_t, T> AbsoluteDifference(T X, T Y) { + return std::max(X, Y) - std::min(X, Y); +} + +/// Add two unsigned integers, X and Y, of type T. Clamp the result to the +/// maximum representable value of T on overflow. ResultOverflowed indicates if +/// the result is larger than the maximum representable value of type T. +template +std::enable_if_t, T> SaturatingAdd( + T X, + T Y, + bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + // Hacker's Delight, p. 29 + T Z = X + Y; + Overflowed = (Z < X || Z < Y); + if (Overflowed) + return std::numeric_limits::max(); + else + return Z; +} + +/// Multiply two unsigned integers, X and Y, of type T. Clamp the result to the +/// maximum representable value of T on overflow. ResultOverflowed indicates if +/// the result is larger than the maximum representable value of type T. +template +std::enable_if_t, T> SaturatingMultiply( + T X, + T Y, + bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + + // Hacker's Delight, p. 30 has a different algorithm, but we don't use that + // because it fails for uint16_t (where multiplication can have undefined + // behavior due to promotion to int), and requires a division in addition + // to the multiplication. + + Overflowed = false; + + // Log2(Z) would be either Log2Z or Log2Z + 1. + // Special case: if X or Y is 0, Log2_64 gives -1, and Log2Z + // will necessarily be less than Log2Max as desired. + int Log2Z = Log2_64(X) + Log2_64(Y); + const T Max = std::numeric_limits::max(); + int Log2Max = Log2_64(Max); + if (Log2Z < Log2Max) { + return X * Y; + } + if (Log2Z > Log2Max) { + Overflowed = true; + return Max; + } + + // We're going to use the top bit, and maybe overflow one + // bit past it. Multiply all but the bottom bit then add + // that on at the end. + T Z = (X >> 1) * Y; + if (Z & ~(Max >> 1)) { + Overflowed = true; + return Max; + } + Z <<= 1; + if (X & 1) + return SaturatingAdd(Z, Y, ResultOverflowed); + + return Z; +} + +/// Multiply two unsigned integers, X and Y, and add the unsigned integer, A to +/// the product. Clamp the result to the maximum representable value of T on +/// overflow. ResultOverflowed indicates if the result is larger than the +/// maximum representable value of type T. +template +std::enable_if_t, T> SaturatingMultiplyAdd( + T X, + T Y, + T A, + bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + + T Product = SaturatingMultiply(X, Y, &Overflowed); + if (Overflowed) + return Product; + + return SaturatingAdd(A, Product, &Overflowed); +} + +/// Use this rather than HUGE_VALF; the latter causes warnings on MSVC. +extern const float huge_valf; +} // namespace c10::llvm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_common.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_common.h new file mode 100644 index 0000000000000000000000000000000000000000..8d881f4de245b1fe650b322cffa1dc294bd019dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_common.h @@ -0,0 +1,79 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_COMMON_H_ +#define C10_UTIL_LOGGING_COMMON_H_ + +#include +#include + +namespace c10 { + +// MessageLogger that throws exceptions instead of aborting (glog version) +// or logs and may abort (non-glog version). +class C10_API MessageLogger { + public: + MessageLogger( + const char* file, + int line, + int severity, + bool exit_on_fatal = true); + ~MessageLogger() noexcept(false); + + // Return the stream associated with the logger object. + std::stringstream& stream(); + + private: + // When there is a fatal log, and fatal == true, we abort + // otherwise, we throw. + void DealWithFatal(); + +#if defined(ANDROID) && !defined(C10_USE_GLOG) + const char* tag_{"native"}; +#endif + std::stringstream stream_; + int severity_; + bool exit_on_fatal_; +}; + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". +class C10_API LoggerVoidify { + public: + LoggerVoidify() = default; + // This has to be an operator with a precedence lower than << but + // higher than ?: + void operator&(const std::ostream& s [[maybe_unused]]) {} +}; + +// Forward declarations for CheckNotNull functions +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal = true); + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal = true); + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal = true); + +} // namespace c10 + +#endif // C10_UTIL_LOGGING_COMMON_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_is_google_glog.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_is_google_glog.h new file mode 100644 index 0000000000000000000000000000000000000000..082e0b86484f7b62d7b0d383d6c717ef5a9d9340 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_is_google_glog.h @@ -0,0 +1,110 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_ +#define C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_ + +#include +#include +#include + +#include // because some of the caffe2 code uses e.g. std::setw +// Using google glog. For glog 0.3.2 versions, stl_logging.h needs to be before +// logging.h to actually use stl_logging. Because template magic. +// In addition, we do not do stl logging in .cu files because nvcc does not like +// it. Some mobile platforms do not like stl_logging, so we add an +// overload in that case as well. + +#ifdef __CUDACC__ +#include +#endif + +#if !defined(__CUDACC__) && !defined(C10_USE_MINIMAL_GLOG) +#include + +// Old versions of glog don't declare this using declaration, so help +// them out. Fortunately, C++ won't complain if you declare the same +// using declaration multiple times. +namespace std { +using ::operator<<; +} + +#else // !defined(__CUDACC__) && !defined(C10_USE_MINIMAL_GLOG) + +// In the cudacc compiler scenario, we will simply ignore the container +// printout feature. Basically we need to register a fake overload for +// vector/string - here, we just ignore the entries in the logs. + +namespace std { +#define INSTANTIATE_FOR_CONTAINER(container) \ + template \ + ostream& operator<<(ostream& out, const container&) { \ + return out; \ + } + +INSTANTIATE_FOR_CONTAINER(vector) +INSTANTIATE_FOR_CONTAINER(map) +INSTANTIATE_FOR_CONTAINER(set) +#undef INSTANTIATE_FOR_CONTAINER +} // namespace std + +#endif + +#include +#include + +namespace c10 { + +[[noreturn]] void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller); + +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + if (t == nullptr) { + MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream() + << "Check failed: '" << names << "' must be non NULL. "; + } + return t; +} + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +} // namespace c10 + +// Log with source location information override (to be used in generic +// warning/error handlers implemented as functions, not macros) +// +// Note, we don't respect GOOGLE_STRIP_LOG here for simplicity +#define LOG_AT_FILE_LINE(n, file, line) \ + ::google::LogMessage(file, line, ::google::GLOG_##n).stream() + +#endif // C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_is_not_google_glog.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_is_not_google_glog.h new file mode 100644 index 0000000000000000000000000000000000000000..efeffb93afc3e05f0780dc554534fb856c49de3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/logging_is_not_google_glog.h @@ -0,0 +1,186 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ +#define C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV"; + +namespace c10 { + +// Log severity level constants. +const int GLOG_FATAL = 3; +const int GLOG_ERROR = 2; +const int GLOG_WARNING = 1; +const int GLOG_INFO = 0; + +// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw +// pointers and smart pointers. +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + if (t == nullptr) { + MessageLogger(file, line, GLOG_FATAL, fatal).stream() + << "Check failed: '" << names << "' must be non NULL. "; + } + return t; +} + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} +} // namespace c10 + +// ---------------------- Logging Macro definitions -------------------------- + +static_assert( + CAFFE2_LOG_THRESHOLD <= ::c10::GLOG_FATAL, + "CAFFE2_LOG_THRESHOLD should at most be GLOG_FATAL."); +// If n is under the compile time caffe log threshold, The _CAFFE_LOG(n) +// should not generate anything in optimized code. +#define LOG(n) \ + if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD) \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() +#define VLOG(n) \ + if (-n >= CAFFE2_LOG_THRESHOLD) \ + ::c10::MessageLogger(__FILE__, __LINE__, -n).stream() + +#define LOG_IF(n, condition) \ + if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD && (condition)) \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() +#define VLOG_IF(n, condition) \ + if (-n >= CAFFE2_LOG_THRESHOLD && (condition)) \ + ::c10::MessageLogger(__FILE__, __LINE__, -n).stream() + +#define VLOG_IS_ON(verboselevel) (CAFFE2_LOG_THRESHOLD <= -(verboselevel)) + +// Log with source location information override (to be used in generic +// warning/error handlers implemented as functions, not macros) +#define LOG_AT_FILE_LINE(n, file, line) \ + if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD) \ + ::c10::MessageLogger(file, line, ::c10::GLOG_##n).stream() + +// Log only if condition is met. Otherwise evaluates to void. +#define FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream() + +// Check for a given boolean condition. +#define CHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " " + +#ifndef NDEBUG +// Debug only version of CHECK +#define DCHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " " +#define DLOG(severity) LOG(severity) +#else // NDEBUG +// Optimized version - generates no code. +#define DCHECK(condition) \ + while (false) \ + CHECK(condition) + +#define DLOG(n) \ + true ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() +#endif // NDEBUG + +// ---------------------- Support for std objects -------------------------- +// These are adapted from glog to support a limited set of logging capability +// for STL objects. + +namespace std { +// Forward declare these two, and define them after all the container streams +// operators so that we can recurse from pair -> container -> container -> pair +// properly. +template +std::ostream& operator<<(std::ostream& out, const std::pair& p); +} // namespace std + +namespace c10 { +template +void PrintSequence(std::ostream& ss, Iter begin, Iter end); +} // namespace c10 + +namespace std { +#define INSTANTIATE_FOR_CONTAINER(container) \ + template \ + std::ostream& operator<<( \ + std::ostream& out, const container& seq) { \ + c10::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ + } + +INSTANTIATE_FOR_CONTAINER(std::vector) +INSTANTIATE_FOR_CONTAINER(std::map) +INSTANTIATE_FOR_CONTAINER(std::set) +#undef INSTANTIATE_FOR_CONTAINER + +template +inline std::ostream& operator<<( + std::ostream& out, + const std::pair& p) { + out << '(' << p.first << ", " << p.second << ')'; + return out; +} + +inline std::ostream& operator<<( + std::ostream& out, + const std::nullptr_t& /*unused*/) { + out << "(null)"; + return out; +} +} // namespace std + +namespace c10 { +template +inline void PrintSequence(std::ostream& out, Iter begin, Iter end) { + // Output at most 100 elements -- appropriate if used for logging. + for (int i = 0; begin != end && i < 100; ++i, ++begin) { + if (i > 0) + out << ' '; + out << *begin; + } + if (begin != end) { + out << " ..."; + } +} +} // namespace c10 + +#endif // C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/numa.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/numa.h new file mode 100644 index 0000000000000000000000000000000000000000..4ae58609b5d56135e59075d5428e03a6c99ff230 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/numa.h @@ -0,0 +1,46 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +C10_DECLARE_bool(caffe2_cpu_numa_enabled); + +namespace c10 { + +/** + * Check whether NUMA is enabled + */ +C10_API bool IsNUMAEnabled(); + +/** + * Bind to a given NUMA node + */ +C10_API void NUMABind(int numa_node_id); + +/** + * Get the NUMA id for a given pointer `ptr` + */ +C10_API int GetNUMANode(const void* ptr); + +/** + * Get number of NUMA nodes + */ +C10_API int GetNumNUMANodes(); + +/** + * Move the memory pointed to by `ptr` of a given size to another NUMA node + */ +C10_API void NUMAMove(void* ptr, size_t size, int numa_node_id); + +/** + * Get the current NUMA node id + */ +C10_API int GetCurrentNUMANode(); + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/order_preserving_flat_hash_map.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/order_preserving_flat_hash_map.h new file mode 100644 index 0000000000000000000000000000000000000000..e991a567ec5eac9c967f4743255de1eb51c9338a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/order_preserving_flat_hash_map.h @@ -0,0 +1,2222 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Taken from +// https://github.com/skarupke/flat_hash_map/blob/2c4687431f978f02a3780e24b8b701d22aa32d9c/flat_hash_map.hpp +// with fixes applied: +// - https://github.com/skarupke/flat_hash_map/pull/25 +// - https://github.com/skarupke/flat_hash_map/pull/26 +// - replace size_t with uint64_t to fix it for 32bit +// - add "GCC diagnostic" pragma to ignore -Wshadow +// - make sherwood_v3_table::convertible_to_iterator public because GCC5 seems +// to have issues with it otherwise +// - fix compiler warnings in operator templated_iterator +// - make use of 'if constexpr' and eliminate AssignIfTrue template + +// Copyright Malte Skarupke 2017. +// Distributed under the Boost Software License, Version 1.0. +// (See http://www.boost.org/LICENSE_1_0.txt) + +// Modified to maintain insertion and deletion order through a doubly-linked +// list + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#define SKA_NOINLINE(...) __declspec(noinline) __VA_ARGS__ +#else +#define SKA_NOINLINE(...) __VA_ARGS__ __attribute__((noinline)) +#endif + +namespace ska_ordered { + +struct prime_number_hash_policy; +struct power_of_two_hash_policy; +struct fibonacci_hash_policy; + +namespace detailv3 { +template +struct functor_storage : Functor { + functor_storage() = default; + functor_storage(const Functor& functor) : Functor(functor) {} + template + Result operator()(Args&&... args) { + return static_cast(*this)(std::forward(args)...); + } + template + Result operator()(Args&&... args) const { + return static_cast(*this)(std::forward(args)...); + } +}; +template +struct functor_storage { + typedef Result (*function_ptr)(Args...); + function_ptr function; + functor_storage(function_ptr function) : function(function) {} + Result operator()(Args... args) const { + return function(std::forward(args)...); + } + operator function_ptr&() { + return function; + } + operator const function_ptr&() { + return function; + } +}; +template +struct KeyOrValueHasher : functor_storage { + typedef functor_storage hasher_storage; + KeyOrValueHasher() = default; + KeyOrValueHasher(const hasher& hash) : hasher_storage(hash) {} + uint64_t operator()(const key_type& key) { + return static_cast(*this)(key); + } + uint64_t operator()(const key_type& key) const { + return static_cast(*this)(key); + } + uint64_t operator()(const value_type& value) { + return static_cast(*this)(value.first); + } + uint64_t operator()(const value_type& value) const { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair& value) { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair& value) const { + return static_cast(*this)(value.first); + } +}; +template +struct KeyOrValueEquality : functor_storage { + typedef functor_storage equality_storage; + KeyOrValueEquality() = default; + KeyOrValueEquality(const key_equal& equality) : equality_storage(equality) {} + bool operator()(const key_type& lhs, const key_type& rhs) { + return static_cast(*this)(lhs, rhs); + } + bool operator()(const key_type& lhs, const value_type& rhs) { + return static_cast(*this)(lhs, rhs.first); + } + bool operator()(const value_type& lhs, const key_type& rhs) { + return static_cast(*this)(lhs.first, rhs); + } + bool operator()(const value_type& lhs, const value_type& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const key_type& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs, rhs.first); + } + template + bool operator()(const std::pair& lhs, const key_type& rhs) { + return static_cast(*this)(lhs.first, rhs); + } + template + bool operator()(const value_type& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair& lhs, const value_type& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } +}; +static constexpr int8_t min_lookups = 4; +template +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +struct sherwood_v3_entry { + // NOLINTNEXTLINE(modernize-use-equals-default) + sherwood_v3_entry() {} + sherwood_v3_entry(int8_t distance_from_desired) + : distance_from_desired(distance_from_desired) {} + // NOLINTNEXTLINE(modernize-use-equals-default) + ~sherwood_v3_entry() {} + + bool has_value() const { + return distance_from_desired >= 0; + } + bool is_empty() const { + return distance_from_desired < 0; + } + bool is_at_desired_position() const { + return distance_from_desired <= 0; + } + template + void emplace(int8_t distance, Args&&... args) { + new (std::addressof(value)) T(std::forward(args)...); + distance_from_desired = distance; + } + + void destroy_value() { + value.~T(); + distance_from_desired = -1; + } + + sherwood_v3_entry* prev = nullptr; + sherwood_v3_entry* next = nullptr; + int8_t distance_from_desired = -1; + static constexpr int8_t special_end_value = 0; + union { + T value; + }; +}; + +inline int8_t log2(uint64_t value) { + static constexpr std::array table = { + 63, 0, 58, 1, 59, 47, 53, 2, 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, 44, 24, 15, 8, 23, 7, 6, 5}; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + return table[((value - (value >> 1)) * 0x07EDD5E59A4E28C2) >> 58]; +} + +inline uint64_t next_power_of_two(uint64_t i) { + --i; + i |= i >> 1; + i |= i >> 2; + i |= i >> 4; + i |= i >> 8; + i |= i >> 16; + i |= i >> 32; + ++i; + return i; +} + +// Implementation taken from http://en.cppreference.com/w/cpp/types/void_t +// (it takes CWG1558 into account and also works for older compilers) +template +struct make_void { + typedef void type; +}; +template +using void_t = typename make_void::type; + +template +struct HashPolicySelector { + typedef fibonacci_hash_policy type; +}; +template +struct HashPolicySelector> { + typedef typename T::hash_policy type; +}; + +template < + typename T, + typename FindKey, + typename ArgumentHash, + typename Hasher, + typename ArgumentEqual, + typename Equal, + typename ArgumentAlloc, + typename EntryAlloc> +class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { + using Entry = detailv3::sherwood_v3_entry; + using AllocatorTraits = std::allocator_traits; + using EntryPointer = typename AllocatorTraits::pointer; + + public: + struct convertible_to_iterator; + + using value_type = T; + using size_type = uint64_t; + using difference_type = std::ptrdiff_t; + using hasher = ArgumentHash; + using key_equal = ArgumentEqual; + using allocator_type = EntryAlloc; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + + sherwood_v3_table() = default; + explicit sherwood_v3_table( + size_type bucket_count, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : EntryAlloc(alloc), Hasher(hash), Equal(equal) { + rehash(bucket_count); + } + sherwood_v3_table(size_type bucket_count, const ArgumentAlloc& alloc) + : sherwood_v3_table( + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table(bucket_count, hash, ArgumentEqual(), alloc) {} + explicit sherwood_v3_table(const ArgumentAlloc& alloc) : EntryAlloc(alloc) {} + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count = 0, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) { + insert(first, last); + } + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + first, + last, + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + first, + last, + bucket_count, + hash, + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count = 0, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) { + if (bucket_count == 0) + rehash(il.size()); + insert(il.begin(), il.end()); + } + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + il, + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table(il, bucket_count, hash, ArgumentEqual(), alloc) {} + sherwood_v3_table(const sherwood_v3_table& other) + : sherwood_v3_table( + other, + AllocatorTraits::select_on_container_copy_construction( + other.get_allocator())) {} + sherwood_v3_table(const sherwood_v3_table& other, const ArgumentAlloc& alloc) + : EntryAlloc(alloc), + Hasher(other), + Equal(other), + _max_load_factor(other._max_load_factor) { + rehash_for_other_container(other); + try { + insert(other.begin(), other.end()); + } catch (...) { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + throw; + } + } + sherwood_v3_table(sherwood_v3_table&& other) noexcept + : EntryAlloc(std::move(other)), + Hasher(std::move(other)), + Equal(std::move(other)) { + swap_pointers(other); + } + sherwood_v3_table( + sherwood_v3_table&& other, + const ArgumentAlloc& alloc) noexcept + : EntryAlloc(alloc), Hasher(std::move(other)), Equal(std::move(other)) { + swap_pointers(other); + } + sherwood_v3_table& operator=(const sherwood_v3_table& other) { + if (this == std::addressof(other)) + return *this; + + clear(); + if constexpr (AllocatorTraits::propagate_on_container_copy_assignment:: + value) { + if (static_cast(*this) != + static_cast(other)) { + reset_to_empty_state(); + } + static_cast(*this) = other; + } + _max_load_factor = other._max_load_factor; + static_cast(*this) = other; + static_cast(*this) = other; + rehash_for_other_container(other); + insert(other.begin(), other.end()); + return *this; + } + sherwood_v3_table& operator=(sherwood_v3_table&& other) noexcept { + if (this == std::addressof(other)) + return *this; + else if constexpr (AllocatorTraits::propagate_on_container_move_assignment:: + value) { + clear(); + reset_to_empty_state(); + static_cast(*this) = std::move(other); + swap_pointers(other); + } else if ( + static_cast(*this) == static_cast(other)) { + swap_pointers(other); + } else { + clear(); + _max_load_factor = other._max_load_factor; + rehash_for_other_container(other); + for (T& elem : other) + emplace(std::move(elem)); + other.clear(); + } + static_cast(*this) = std::move(other); + static_cast(*this) = std::move(other); + return *this; + } + ~sherwood_v3_table() { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + } + + const allocator_type& get_allocator() const { + return static_cast(*this); + } + const ArgumentEqual& key_eq() const { + return static_cast(*this); + } + const ArgumentHash& hash_function() const { + return static_cast(*this); + } + + template + struct templated_iterator { + templated_iterator() = default; + templated_iterator(EntryPointer current) : current(current) {} + EntryPointer current = EntryPointer(); + + using iterator_category = std::forward_iterator_tag; + using value_type = ValueType; + using difference_type = ptrdiff_t; + using pointer = ValueType*; + using reference = ValueType&; + + friend bool operator==( + const templated_iterator& lhs, + const templated_iterator& rhs) { + return lhs.current == rhs.current; + } + friend bool operator!=( + const templated_iterator& lhs, + const templated_iterator& rhs) { + return !(lhs == rhs); + } + + templated_iterator& operator++() { + current = current->next; + return *this; + } + templated_iterator operator++(int) { + templated_iterator copy(*this); + ++*this; + return copy; + } + + ValueType& operator*() const { + return current->value; + } + ValueType* operator->() const { + return std::addressof(current->value); + } + + // the template automatically disables the operator when value_type is + // already const, because that would cause a lot of compiler warnings + // otherwise. + template < + class target_type = const value_type, + class = std::enable_if_t< + std::is_same_v && + !std::is_same_v>> + operator templated_iterator() const { + return {current}; + } + }; + using iterator = templated_iterator; + using const_iterator = templated_iterator; + + iterator begin() { + return sentinel->next; + } + const_iterator begin() const { + return sentinel->next; + } + const_iterator cbegin() const { + return begin(); + } + iterator end() { + return sentinel; + } + const_iterator end() const { + return sentinel; + } + const_iterator cend() const { + return end(); + } + + iterator find(const FindKey& key) { + uint64_t index = + hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer it = entries + ptrdiff_t(index); + for (int8_t distance = 0; it->distance_from_desired >= distance; + ++distance, ++it) { + if (compares_equal(key, it->value)) + return {it}; + } + return end(); + } + const_iterator find(const FindKey& key) const { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + return const_cast(this)->find(key); + } + uint64_t count(const FindKey& key) const { + return find(key) == end() ? 0 : 1; + } + std::pair equal_range(const FindKey& key) { + iterator found = find(key); + if (found == end()) + return {found, found}; + else + return {found, std::next(found)}; + } + std::pair equal_range( + const FindKey& key) const { + const_iterator found = find(key); + if (found == end()) + return {found, found}; + else + return {found, std::next(found)}; + } + + template + std::pair emplace(Key&& key, Args&&... args) { + uint64_t index = + hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer current_entry = entries + ptrdiff_t(index); + int8_t distance_from_desired = 0; + for (; current_entry->distance_from_desired >= distance_from_desired; + ++current_entry, ++distance_from_desired) { + // insertion of an existing key does not change ordering + if (compares_equal(key, current_entry->value)) + return {{current_entry}, false}; + } + return emplace_new_key( + distance_from_desired, + current_entry, + std::forward(key), + std::forward(args)...); + } + + std::pair insert(const value_type& value) { + return emplace(value); + } + std::pair insert(value_type&& value) { + return emplace(std::move(value)); + } + template + iterator emplace_hint(const_iterator /*unused*/, Args&&... args) { + return emplace(std::forward(args)...).first; + } + iterator insert(const_iterator /*unused*/, const value_type& value) { + return emplace(value).first; + } + iterator insert(const_iterator /*unused*/, value_type&& value) { + return emplace(std::move(value)).first; + } + + template + void insert(It begin, It end) { + for (; begin != end; ++begin) { + emplace(*begin); + } + } + void insert(std::initializer_list il) { + insert(il.begin(), il.end()); + } + + void rehash(uint64_t num_buckets) { + num_buckets = std::max( + num_buckets, + static_cast(std::ceil( + static_cast(num_elements) / + static_cast(_max_load_factor)))); + if (num_buckets == 0) { + reset_to_empty_state(); + return; + } + auto new_prime_index = hash_policy.next_size_over(num_buckets); + if (num_buckets == bucket_count()) + return; + int8_t new_max_lookups = compute_max_lookups(num_buckets); + EntryPointer new_buckets( + AllocatorTraits::allocate(*this, num_buckets + new_max_lookups)); + EntryPointer special_end_item = + new_buckets + static_cast(num_buckets + new_max_lookups - 1); + for (EntryPointer it = new_buckets; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + std::swap(entries, new_buckets); + std::swap(num_slots_minus_one, num_buckets); + --num_slots_minus_one; + hash_policy.commit(new_prime_index); + int8_t old_max_lookups = max_lookups; + max_lookups = new_max_lookups; + num_elements = 0; + + auto start = sentinel->next; + // point sentinel to itself; + reset_list(); + // reinsert list + for (EntryPointer it = start; it != sentinel;) { + auto next = it->next; + emplace(std::move(it->value)); + it->destroy_value(); + it = next; + } + + deallocate_data(new_buckets, num_buckets, old_max_lookups); + } + + void reserve(uint64_t num_elements_) { + uint64_t required_buckets = num_buckets_for_reserve(num_elements_); + if (required_buckets > bucket_count()) + rehash(required_buckets); + } + + void replace_linked_list_position( + EntryPointer to_be_replaced, + EntryPointer new_node) { + remove_from_list(new_node); + insert_after(new_node, to_be_replaced->prev); + remove_from_list(to_be_replaced); + } + + // the return value is a type that can be converted to an iterator + // the reason for doing this is that it's not free to find the + // iterator pointing at the next element. if you care about the + // next iterator, turn the return value into an iterator + convertible_to_iterator erase(const_iterator to_erase) { + EntryPointer current = to_erase.current; + remove_from_list(current); + current->destroy_value(); + --num_elements; + + for (EntryPointer next = current + ptrdiff_t(1); + !next->is_at_desired_position(); + ++current, ++next) { + // if an entry is being removed, and there are other entries with the + // same hash, the other entries get moved to their desired position by + // reinserting. + current->emplace(next->distance_from_desired - 1, std::move(next->value)); + replace_linked_list_position(next, current); + next->destroy_value(); + } + return {to_erase.current}; + } + + iterator erase(const_iterator begin_it, const_iterator end_it) { + // whenever an entry is removed and there are other entries with the same + // hash, the other entries must get moved to their desired position. + // any reference to a moved entry is invalidated. + // here, we iterate through the range, and make sure that we update + // the pointer to our next entry in the list or the end of the iterator + // when it is invalidated. + + auto curr_iter = begin_it.current; + auto next_iter = curr_iter->next; + auto end_iter = end_it.current; + + while (curr_iter != end_iter) { + remove_from_list(curr_iter); + curr_iter->destroy_value(); + --num_elements; + + for (EntryPointer next_hash_slot = curr_iter + ptrdiff_t(1); + !next_hash_slot->is_at_desired_position(); + ++curr_iter, ++next_hash_slot) { + curr_iter->emplace( + next_hash_slot->distance_from_desired - 1, + std::move(next_hash_slot->value)); + replace_linked_list_position(next_hash_slot, curr_iter); + next_hash_slot->destroy_value(); + + // we are invalidating next_iter or end_iter + if (next_hash_slot == end_iter) { + end_iter = curr_iter; + } else if (next_hash_slot == next_iter) { + next_iter = curr_iter; + } + } + curr_iter = next_iter; + next_iter = curr_iter->next; + } + + return {end_iter}; + } + + uint64_t erase(const FindKey& key) { + auto found = find(key); + if (found == end()) + return 0; + else { + erase(found); + return 1; + } + } + + void clear() { + for (EntryPointer it = entries, + end = it + + static_cast(num_slots_minus_one + max_lookups); + it != end; + ++it) { + if (it->has_value()) + it->destroy_value(); + } + reset_list(); + num_elements = 0; + } + + void shrink_to_fit() { + rehash_for_other_container(*this); + } + + void swap(sherwood_v3_table& other) noexcept { + using std::swap; + swap_pointers(other); + swap(static_cast(*this), static_cast(other)); + swap( + static_cast(*this), static_cast(other)); + if (AllocatorTraits::propagate_on_container_swap::value) + swap(static_cast(*this), static_cast(other)); + } + + uint64_t size() const { + return num_elements; + } + uint64_t max_size() const { + return (AllocatorTraits::max_size(*this)) / sizeof(Entry); + } + uint64_t bucket_count() const { + return num_slots_minus_one ? num_slots_minus_one + 1 : 0; + } + size_type max_bucket_count() const { + return (AllocatorTraits::max_size(*this) - min_lookups) / sizeof(Entry); + } + uint64_t bucket(const FindKey& key) const { + return hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + } + float load_factor() const { + uint64_t buckets = bucket_count(); + if (buckets) + return static_cast(num_elements) / bucket_count(); + else + return 0; + } + void max_load_factor(float value) { + _max_load_factor = value; + } + float max_load_factor() const { + return _max_load_factor; + } + + bool empty() const { + return num_elements == 0; + } + + private: + EntryPointer entries = empty_default_table(); + uint64_t num_slots_minus_one = 0; + typename HashPolicySelector::type hash_policy; + int8_t max_lookups = detailv3::min_lookups - 1; + float _max_load_factor = 0.5f; + uint64_t num_elements = 0; + std::unique_ptr> sentinel_val; + + // head of doubly linked list + EntryPointer sentinel = initSentinel(); + + EntryPointer initSentinel() { + // needs to be a pointer so that hash map can be used with forward declared + // types + sentinel_val = std::make_unique>(); + sentinel = sentinel_val.get(); + reset_list(); + return sentinel; + } + + EntryPointer empty_default_table() { + EntryPointer result = + AllocatorTraits::allocate(*this, detailv3::min_lookups); + EntryPointer special_end_item = + result + static_cast(detailv3::min_lookups - 1); + for (EntryPointer it = result; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + return result; + } + + static int8_t compute_max_lookups(uint64_t num_buckets) { + int8_t desired = detailv3::log2(num_buckets); + return std::max(detailv3::min_lookups, desired); + } + + uint64_t num_buckets_for_reserve(uint64_t num_elements_) const { + return static_cast(std::ceil( + static_cast(num_elements_) / + std::min(0.5, static_cast(_max_load_factor)))); + } + void rehash_for_other_container(const sherwood_v3_table& other) { + rehash( + std::min(num_buckets_for_reserve(other.size()), other.bucket_count())); + } + + void swap_pointers(sherwood_v3_table& other) { + using std::swap; + swap(hash_policy, other.hash_policy); + swap(entries, other.entries); + swap(num_slots_minus_one, other.num_slots_minus_one); + swap(num_elements, other.num_elements); + swap(max_lookups, other.max_lookups); + swap(_max_load_factor, other._max_load_factor); + swap(sentinel, other.sentinel); + swap(sentinel_val, other.sentinel_val); + } + + void reset_list() { + sentinel->next = sentinel; + sentinel->prev = sentinel; + } + + void remove_from_list(EntryPointer elem) { + elem->prev->next = elem->next; + elem->next->prev = elem->prev; + } + + void insert_after(EntryPointer new_elem, EntryPointer prev) { + auto next = prev->next; + + prev->next = new_elem; + new_elem->prev = prev; + + new_elem->next = next; + next->prev = new_elem; + } + + void swap_adjacent_nodes(EntryPointer before, EntryPointer after) { + // sentinel stays constant, so before->prev cannot equal after + auto before_prev = before->prev; + auto after_next = after->next; + + before_prev->next = after; + after->prev = before_prev; + + after_next->prev = before; + before->next = after_next; + + before->prev = after; + after->next = before; + } + + void swap_positions(EntryPointer p1, EntryPointer p2) { + if (p1 == p2) { + return; + } + if (p1->next == p2) { + return swap_adjacent_nodes(p1, p2); + } else if (p2->next == p1) { + return swap_adjacent_nodes(p2, p1); + } + + auto p1_prev = p1->prev; + auto p1_next = p1->next; + + auto p2_prev = p2->prev; + auto p2_next = p2->next; + + p1_prev->next = p2; + p2->prev = p1_prev; + + p1_next->prev = p2; + p2->next = p1_next; + + p2_prev->next = p1; + p1->prev = p2_prev; + + p2_next->prev = p1; + p1->next = p2_next; + } + + void append_to_list(EntryPointer new_tail) { + insert_after(new_tail, sentinel->prev); + } + + template + SKA_NOINLINE(std::pair) + emplace_new_key( + int8_t distance_from_desired, + EntryPointer current_entry, + Key&& key, + Args&&... args) { + using std::swap; + if (num_slots_minus_one == 0 || distance_from_desired == max_lookups || + static_cast(num_elements + 1) > + static_cast(num_slots_minus_one + 1) * + static_cast(_max_load_factor)) { + grow(); + return emplace(std::forward(key), std::forward(args)...); + } else if (current_entry->is_empty()) { + current_entry->emplace( + distance_from_desired, + std::forward(key), + std::forward(args)...); + ++num_elements; + append_to_list(current_entry); + return {{current_entry}, true}; + } + value_type to_insert(std::forward(key), std::forward(args)...); + swap(distance_from_desired, current_entry->distance_from_desired); + // We maintain the invariant that: + // - result.current_entry contains the new value we're inserting + // and is in the LinkedList position of to_insert + // - to_insert contains the value that represents the position of + // result.current_entry + swap(to_insert, current_entry->value); + iterator result = {current_entry}; + for (++distance_from_desired, ++current_entry;; ++current_entry) { + if (current_entry->is_empty()) { + current_entry->emplace(distance_from_desired, std::move(to_insert)); + append_to_list(current_entry); + // now we can swap back the displaced value to its correct position, + // putting the new value we're inserting to the front of the list + swap_positions(current_entry, result.current); + ++num_elements; + return {result, true}; + } else if (current_entry->distance_from_desired < distance_from_desired) { + swap(distance_from_desired, current_entry->distance_from_desired); + swap(to_insert, current_entry->value); + // to maintain our invariants we need to swap positions + // of result.current & current_entry: + swap_positions(result.current, current_entry); + ++distance_from_desired; + } else { + ++distance_from_desired; + if (distance_from_desired == max_lookups) { + // the displaced element gets put back into its correct position + // we grow the hash table, and then try again to reinsert the new + // element + swap(to_insert, result.current->value); + grow(); + return emplace(std::move(to_insert)); + } + } + } + } + + void grow() { + rehash(std::max(uint64_t(4), 2 * bucket_count())); + } + + void deallocate_data( + EntryPointer begin, + uint64_t num_slots_minus_one_, + int8_t max_lookups_) { + AllocatorTraits::deallocate( + *this, begin, num_slots_minus_one_ + max_lookups_ + 1); + } + + void reset_to_empty_state() { + deallocate_data(entries, num_slots_minus_one, max_lookups); + entries = empty_default_table(); + num_slots_minus_one = 0; + hash_policy.reset(); + max_lookups = detailv3::min_lookups - 1; + } + + template + uint64_t hash_object(const U& key) { + return static_cast(*this)(key); + } + template + uint64_t hash_object(const U& key) const { + return static_cast(*this)(key); + } + template + bool compares_equal(const L& lhs, const R& rhs) { + return static_cast(*this)(lhs, rhs); + } + + public: + struct convertible_to_iterator { + EntryPointer it; + + operator iterator() { + if (it->has_value()) + return {it}; + else + return ++iterator{it}; + } + operator const_iterator() { + if (it->has_value()) + return {it}; + else + return ++const_iterator{it}; + } + }; +}; +} // namespace detailv3 + +struct prime_number_hash_policy { + static uint64_t mod0(uint64_t /*unused*/) { + return 0llu; + } + static uint64_t mod2(uint64_t hash) { + return hash % 2llu; + } + static uint64_t mod3(uint64_t hash) { + return hash % 3llu; + } + static uint64_t mod5(uint64_t hash) { + return hash % 5llu; + } + static uint64_t mod7(uint64_t hash) { + return hash % 7llu; + } + static uint64_t mod11(uint64_t hash) { + return hash % 11llu; + } + static uint64_t mod13(uint64_t hash) { + return hash % 13llu; + } + static uint64_t mod17(uint64_t hash) { + return hash % 17llu; + } + static uint64_t mod23(uint64_t hash) { + return hash % 23llu; + } + static uint64_t mod29(uint64_t hash) { + return hash % 29llu; + } + static uint64_t mod37(uint64_t hash) { + return hash % 37llu; + } + static uint64_t mod47(uint64_t hash) { + return hash % 47llu; + } + static uint64_t mod59(uint64_t hash) { + return hash % 59llu; + } + static uint64_t mod73(uint64_t hash) { + return hash % 73llu; + } + static uint64_t mod97(uint64_t hash) { + return hash % 97llu; + } + static uint64_t mod127(uint64_t hash) { + return hash % 127llu; + } + static uint64_t mod151(uint64_t hash) { + return hash % 151llu; + } + static uint64_t mod197(uint64_t hash) { + return hash % 197llu; + } + static uint64_t mod251(uint64_t hash) { + return hash % 251llu; + } + static uint64_t mod313(uint64_t hash) { + return hash % 313llu; + } + static uint64_t mod397(uint64_t hash) { + return hash % 397llu; + } + static uint64_t mod499(uint64_t hash) { + return hash % 499llu; + } + static uint64_t mod631(uint64_t hash) { + return hash % 631llu; + } + static uint64_t mod797(uint64_t hash) { + return hash % 797llu; + } + static uint64_t mod1009(uint64_t hash) { + return hash % 1009llu; + } + static uint64_t mod1259(uint64_t hash) { + return hash % 1259llu; + } + static uint64_t mod1597(uint64_t hash) { + return hash % 1597llu; + } + static uint64_t mod2011(uint64_t hash) { + return hash % 2011llu; + } + static uint64_t mod2539(uint64_t hash) { + return hash % 2539llu; + } + static uint64_t mod3203(uint64_t hash) { + return hash % 3203llu; + } + static uint64_t mod4027(uint64_t hash) { + return hash % 4027llu; + } + static uint64_t mod5087(uint64_t hash) { + return hash % 5087llu; + } + static uint64_t mod6421(uint64_t hash) { + return hash % 6421llu; + } + static uint64_t mod8089(uint64_t hash) { + return hash % 8089llu; + } + static uint64_t mod10193(uint64_t hash) { + return hash % 10193llu; + } + static uint64_t mod12853(uint64_t hash) { + return hash % 12853llu; + } + static uint64_t mod16193(uint64_t hash) { + return hash % 16193llu; + } + static uint64_t mod20399(uint64_t hash) { + return hash % 20399llu; + } + static uint64_t mod25717(uint64_t hash) { + return hash % 25717llu; + } + static uint64_t mod32401(uint64_t hash) { + return hash % 32401llu; + } + static uint64_t mod40823(uint64_t hash) { + return hash % 40823llu; + } + static uint64_t mod51437(uint64_t hash) { + return hash % 51437llu; + } + static uint64_t mod64811(uint64_t hash) { + return hash % 64811llu; + } + static uint64_t mod81649(uint64_t hash) { + return hash % 81649llu; + } + static uint64_t mod102877(uint64_t hash) { + return hash % 102877llu; + } + static uint64_t mod129607(uint64_t hash) { + return hash % 129607llu; + } + static uint64_t mod163307(uint64_t hash) { + return hash % 163307llu; + } + static uint64_t mod205759(uint64_t hash) { + return hash % 205759llu; + } + static uint64_t mod259229(uint64_t hash) { + return hash % 259229llu; + } + static uint64_t mod326617(uint64_t hash) { + return hash % 326617llu; + } + static uint64_t mod411527(uint64_t hash) { + return hash % 411527llu; + } + static uint64_t mod518509(uint64_t hash) { + return hash % 518509llu; + } + static uint64_t mod653267(uint64_t hash) { + return hash % 653267llu; + } + static uint64_t mod823117(uint64_t hash) { + return hash % 823117llu; + } + static uint64_t mod1037059(uint64_t hash) { + return hash % 1037059llu; + } + static uint64_t mod1306601(uint64_t hash) { + return hash % 1306601llu; + } + static uint64_t mod1646237(uint64_t hash) { + return hash % 1646237llu; + } + static uint64_t mod2074129(uint64_t hash) { + return hash % 2074129llu; + } + static uint64_t mod2613229(uint64_t hash) { + return hash % 2613229llu; + } + static uint64_t mod3292489(uint64_t hash) { + return hash % 3292489llu; + } + static uint64_t mod4148279(uint64_t hash) { + return hash % 4148279llu; + } + static uint64_t mod5226491(uint64_t hash) { + return hash % 5226491llu; + } + static uint64_t mod6584983(uint64_t hash) { + return hash % 6584983llu; + } + static uint64_t mod8296553(uint64_t hash) { + return hash % 8296553llu; + } + static uint64_t mod10453007(uint64_t hash) { + return hash % 10453007llu; + } + static uint64_t mod13169977(uint64_t hash) { + return hash % 13169977llu; + } + static uint64_t mod16593127(uint64_t hash) { + return hash % 16593127llu; + } + static uint64_t mod20906033(uint64_t hash) { + return hash % 20906033llu; + } + static uint64_t mod26339969(uint64_t hash) { + return hash % 26339969llu; + } + static uint64_t mod33186281(uint64_t hash) { + return hash % 33186281llu; + } + static uint64_t mod41812097(uint64_t hash) { + return hash % 41812097llu; + } + static uint64_t mod52679969(uint64_t hash) { + return hash % 52679969llu; + } + static uint64_t mod66372617(uint64_t hash) { + return hash % 66372617llu; + } + static uint64_t mod83624237(uint64_t hash) { + return hash % 83624237llu; + } + static uint64_t mod105359939(uint64_t hash) { + return hash % 105359939llu; + } + static uint64_t mod132745199(uint64_t hash) { + return hash % 132745199llu; + } + static uint64_t mod167248483(uint64_t hash) { + return hash % 167248483llu; + } + static uint64_t mod210719881(uint64_t hash) { + return hash % 210719881llu; + } + static uint64_t mod265490441(uint64_t hash) { + return hash % 265490441llu; + } + static uint64_t mod334496971(uint64_t hash) { + return hash % 334496971llu; + } + static uint64_t mod421439783(uint64_t hash) { + return hash % 421439783llu; + } + static uint64_t mod530980861(uint64_t hash) { + return hash % 530980861llu; + } + static uint64_t mod668993977(uint64_t hash) { + return hash % 668993977llu; + } + static uint64_t mod842879579(uint64_t hash) { + return hash % 842879579llu; + } + static uint64_t mod1061961721(uint64_t hash) { + return hash % 1061961721llu; + } + static uint64_t mod1337987929(uint64_t hash) { + return hash % 1337987929llu; + } + static uint64_t mod1685759167(uint64_t hash) { + return hash % 1685759167llu; + } + static uint64_t mod2123923447(uint64_t hash) { + return hash % 2123923447llu; + } + static uint64_t mod2675975881(uint64_t hash) { + return hash % 2675975881llu; + } + static uint64_t mod3371518343(uint64_t hash) { + return hash % 3371518343llu; + } + static uint64_t mod4247846927(uint64_t hash) { + return hash % 4247846927llu; + } + static uint64_t mod5351951779(uint64_t hash) { + return hash % 5351951779llu; + } + static uint64_t mod6743036717(uint64_t hash) { + return hash % 6743036717llu; + } + static uint64_t mod8495693897(uint64_t hash) { + return hash % 8495693897llu; + } + static uint64_t mod10703903591(uint64_t hash) { + return hash % 10703903591llu; + } + static uint64_t mod13486073473(uint64_t hash) { + return hash % 13486073473llu; + } + static uint64_t mod16991387857(uint64_t hash) { + return hash % 16991387857llu; + } + static uint64_t mod21407807219(uint64_t hash) { + return hash % 21407807219llu; + } + static uint64_t mod26972146961(uint64_t hash) { + return hash % 26972146961llu; + } + static uint64_t mod33982775741(uint64_t hash) { + return hash % 33982775741llu; + } + static uint64_t mod42815614441(uint64_t hash) { + return hash % 42815614441llu; + } + static uint64_t mod53944293929(uint64_t hash) { + return hash % 53944293929llu; + } + static uint64_t mod67965551447(uint64_t hash) { + return hash % 67965551447llu; + } + static uint64_t mod85631228929(uint64_t hash) { + return hash % 85631228929llu; + } + static uint64_t mod107888587883(uint64_t hash) { + return hash % 107888587883llu; + } + static uint64_t mod135931102921(uint64_t hash) { + return hash % 135931102921llu; + } + static uint64_t mod171262457903(uint64_t hash) { + return hash % 171262457903llu; + } + static uint64_t mod215777175787(uint64_t hash) { + return hash % 215777175787llu; + } + static uint64_t mod271862205833(uint64_t hash) { + return hash % 271862205833llu; + } + static uint64_t mod342524915839(uint64_t hash) { + return hash % 342524915839llu; + } + static uint64_t mod431554351609(uint64_t hash) { + return hash % 431554351609llu; + } + static uint64_t mod543724411781(uint64_t hash) { + return hash % 543724411781llu; + } + static uint64_t mod685049831731(uint64_t hash) { + return hash % 685049831731llu; + } + static uint64_t mod863108703229(uint64_t hash) { + return hash % 863108703229llu; + } + static uint64_t mod1087448823553(uint64_t hash) { + return hash % 1087448823553llu; + } + static uint64_t mod1370099663459(uint64_t hash) { + return hash % 1370099663459llu; + } + static uint64_t mod1726217406467(uint64_t hash) { + return hash % 1726217406467llu; + } + static uint64_t mod2174897647073(uint64_t hash) { + return hash % 2174897647073llu; + } + static uint64_t mod2740199326961(uint64_t hash) { + return hash % 2740199326961llu; + } + static uint64_t mod3452434812973(uint64_t hash) { + return hash % 3452434812973llu; + } + static uint64_t mod4349795294267(uint64_t hash) { + return hash % 4349795294267llu; + } + static uint64_t mod5480398654009(uint64_t hash) { + return hash % 5480398654009llu; + } + static uint64_t mod6904869625999(uint64_t hash) { + return hash % 6904869625999llu; + } + static uint64_t mod8699590588571(uint64_t hash) { + return hash % 8699590588571llu; + } + static uint64_t mod10960797308051(uint64_t hash) { + return hash % 10960797308051llu; + } + static uint64_t mod13809739252051(uint64_t hash) { + return hash % 13809739252051llu; + } + static uint64_t mod17399181177241(uint64_t hash) { + return hash % 17399181177241llu; + } + static uint64_t mod21921594616111(uint64_t hash) { + return hash % 21921594616111llu; + } + static uint64_t mod27619478504183(uint64_t hash) { + return hash % 27619478504183llu; + } + static uint64_t mod34798362354533(uint64_t hash) { + return hash % 34798362354533llu; + } + static uint64_t mod43843189232363(uint64_t hash) { + return hash % 43843189232363llu; + } + static uint64_t mod55238957008387(uint64_t hash) { + return hash % 55238957008387llu; + } + static uint64_t mod69596724709081(uint64_t hash) { + return hash % 69596724709081llu; + } + static uint64_t mod87686378464759(uint64_t hash) { + return hash % 87686378464759llu; + } + static uint64_t mod110477914016779(uint64_t hash) { + return hash % 110477914016779llu; + } + static uint64_t mod139193449418173(uint64_t hash) { + return hash % 139193449418173llu; + } + static uint64_t mod175372756929481(uint64_t hash) { + return hash % 175372756929481llu; + } + static uint64_t mod220955828033581(uint64_t hash) { + return hash % 220955828033581llu; + } + static uint64_t mod278386898836457(uint64_t hash) { + return hash % 278386898836457llu; + } + static uint64_t mod350745513859007(uint64_t hash) { + return hash % 350745513859007llu; + } + static uint64_t mod441911656067171(uint64_t hash) { + return hash % 441911656067171llu; + } + static uint64_t mod556773797672909(uint64_t hash) { + return hash % 556773797672909llu; + } + static uint64_t mod701491027718027(uint64_t hash) { + return hash % 701491027718027llu; + } + static uint64_t mod883823312134381(uint64_t hash) { + return hash % 883823312134381llu; + } + static uint64_t mod1113547595345903(uint64_t hash) { + return hash % 1113547595345903llu; + } + static uint64_t mod1402982055436147(uint64_t hash) { + return hash % 1402982055436147llu; + } + static uint64_t mod1767646624268779(uint64_t hash) { + return hash % 1767646624268779llu; + } + static uint64_t mod2227095190691797(uint64_t hash) { + return hash % 2227095190691797llu; + } + static uint64_t mod2805964110872297(uint64_t hash) { + return hash % 2805964110872297llu; + } + static uint64_t mod3535293248537579(uint64_t hash) { + return hash % 3535293248537579llu; + } + static uint64_t mod4454190381383713(uint64_t hash) { + return hash % 4454190381383713llu; + } + static uint64_t mod5611928221744609(uint64_t hash) { + return hash % 5611928221744609llu; + } + static uint64_t mod7070586497075177(uint64_t hash) { + return hash % 7070586497075177llu; + } + static uint64_t mod8908380762767489(uint64_t hash) { + return hash % 8908380762767489llu; + } + static uint64_t mod11223856443489329(uint64_t hash) { + return hash % 11223856443489329llu; + } + static uint64_t mod14141172994150357(uint64_t hash) { + return hash % 14141172994150357llu; + } + static uint64_t mod17816761525534927(uint64_t hash) { + return hash % 17816761525534927llu; + } + static uint64_t mod22447712886978529(uint64_t hash) { + return hash % 22447712886978529llu; + } + static uint64_t mod28282345988300791(uint64_t hash) { + return hash % 28282345988300791llu; + } + static uint64_t mod35633523051069991(uint64_t hash) { + return hash % 35633523051069991llu; + } + static uint64_t mod44895425773957261(uint64_t hash) { + return hash % 44895425773957261llu; + } + static uint64_t mod56564691976601587(uint64_t hash) { + return hash % 56564691976601587llu; + } + static uint64_t mod71267046102139967(uint64_t hash) { + return hash % 71267046102139967llu; + } + static uint64_t mod89790851547914507(uint64_t hash) { + return hash % 89790851547914507llu; + } + static uint64_t mod113129383953203213(uint64_t hash) { + return hash % 113129383953203213llu; + } + static uint64_t mod142534092204280003(uint64_t hash) { + return hash % 142534092204280003llu; + } + static uint64_t mod179581703095829107(uint64_t hash) { + return hash % 179581703095829107llu; + } + static uint64_t mod226258767906406483(uint64_t hash) { + return hash % 226258767906406483llu; + } + static uint64_t mod285068184408560057(uint64_t hash) { + return hash % 285068184408560057llu; + } + static uint64_t mod359163406191658253(uint64_t hash) { + return hash % 359163406191658253llu; + } + static uint64_t mod452517535812813007(uint64_t hash) { + return hash % 452517535812813007llu; + } + static uint64_t mod570136368817120201(uint64_t hash) { + return hash % 570136368817120201llu; + } + static uint64_t mod718326812383316683(uint64_t hash) { + return hash % 718326812383316683llu; + } + static uint64_t mod905035071625626043(uint64_t hash) { + return hash % 905035071625626043llu; + } + static uint64_t mod1140272737634240411(uint64_t hash) { + return hash % 1140272737634240411llu; + } + static uint64_t mod1436653624766633509(uint64_t hash) { + return hash % 1436653624766633509llu; + } + static uint64_t mod1810070143251252131(uint64_t hash) { + return hash % 1810070143251252131llu; + } + static uint64_t mod2280545475268481167(uint64_t hash) { + return hash % 2280545475268481167llu; + } + static uint64_t mod2873307249533267101(uint64_t hash) { + return hash % 2873307249533267101llu; + } + static uint64_t mod3620140286502504283(uint64_t hash) { + return hash % 3620140286502504283llu; + } + static uint64_t mod4561090950536962147(uint64_t hash) { + return hash % 4561090950536962147llu; + } + static uint64_t mod5746614499066534157(uint64_t hash) { + return hash % 5746614499066534157llu; + } + static uint64_t mod7240280573005008577(uint64_t hash) { + return hash % 7240280573005008577llu; + } + static uint64_t mod9122181901073924329(uint64_t hash) { + return hash % 9122181901073924329llu; + } + static uint64_t mod11493228998133068689(uint64_t hash) { + return hash % 11493228998133068689llu; + } + static uint64_t mod14480561146010017169(uint64_t hash) { + return hash % 14480561146010017169llu; + } + static uint64_t mod18446744073709551557(uint64_t hash) { + return hash % 18446744073709551557llu; + } + + using mod_function = uint64_t (*)(uint64_t); + + mod_function next_size_over(uint64_t& size) const { + // prime numbers generated by the following method: + // 1. start with a prime p = 2 + // 2. go to wolfram alpha and get p = NextPrime(2 * p) + // 3. repeat 2. until you overflow 64 bits + // you now have large gaps which you would hit if somebody called reserve() + // with an unlucky number. + // 4. to fill the gaps for every prime p go to wolfram alpha and get + // ClosestPrime(p * 2^(1/3)) and ClosestPrime(p * 2^(2/3)) and put those in + // the gaps + // 5. get PrevPrime(2^64) and put it at the end + // NOLINTNEXTLINE(*c-array*) + static constexpr const uint64_t prime_list[] = { + 2llu, + 3llu, + 5llu, + 7llu, + 11llu, + 13llu, + 17llu, + 23llu, + 29llu, + 37llu, + 47llu, + 59llu, + 73llu, + 97llu, + 127llu, + 151llu, + 197llu, + 251llu, + 313llu, + 397llu, + 499llu, + 631llu, + 797llu, + 1009llu, + 1259llu, + 1597llu, + 2011llu, + 2539llu, + 3203llu, + 4027llu, + 5087llu, + 6421llu, + 8089llu, + 10193llu, + 12853llu, + 16193llu, + 20399llu, + 25717llu, + 32401llu, + 40823llu, + 51437llu, + 64811llu, + 81649llu, + 102877llu, + 129607llu, + 163307llu, + 205759llu, + 259229llu, + 326617llu, + 411527llu, + 518509llu, + 653267llu, + 823117llu, + 1037059llu, + 1306601llu, + 1646237llu, + 2074129llu, + 2613229llu, + 3292489llu, + 4148279llu, + 5226491llu, + 6584983llu, + 8296553llu, + 10453007llu, + 13169977llu, + 16593127llu, + 20906033llu, + 26339969llu, + 33186281llu, + 41812097llu, + 52679969llu, + 66372617llu, + 83624237llu, + 105359939llu, + 132745199llu, + 167248483llu, + 210719881llu, + 265490441llu, + 334496971llu, + 421439783llu, + 530980861llu, + 668993977llu, + 842879579llu, + 1061961721llu, + 1337987929llu, + 1685759167llu, + 2123923447llu, + 2675975881llu, + 3371518343llu, + 4247846927llu, + 5351951779llu, + 6743036717llu, + 8495693897llu, + 10703903591llu, + 13486073473llu, + 16991387857llu, + 21407807219llu, + 26972146961llu, + 33982775741llu, + 42815614441llu, + 53944293929llu, + 67965551447llu, + 85631228929llu, + 107888587883llu, + 135931102921llu, + 171262457903llu, + 215777175787llu, + 271862205833llu, + 342524915839llu, + 431554351609llu, + 543724411781llu, + 685049831731llu, + 863108703229llu, + 1087448823553llu, + 1370099663459llu, + 1726217406467llu, + 2174897647073llu, + 2740199326961llu, + 3452434812973llu, + 4349795294267llu, + 5480398654009llu, + 6904869625999llu, + 8699590588571llu, + 10960797308051llu, + 13809739252051llu, + 17399181177241llu, + 21921594616111llu, + 27619478504183llu, + 34798362354533llu, + 43843189232363llu, + 55238957008387llu, + 69596724709081llu, + 87686378464759llu, + 110477914016779llu, + 139193449418173llu, + 175372756929481llu, + 220955828033581llu, + 278386898836457llu, + 350745513859007llu, + 441911656067171llu, + 556773797672909llu, + 701491027718027llu, + 883823312134381llu, + 1113547595345903llu, + 1402982055436147llu, + 1767646624268779llu, + 2227095190691797llu, + 2805964110872297llu, + 3535293248537579llu, + 4454190381383713llu, + 5611928221744609llu, + 7070586497075177llu, + 8908380762767489llu, + 11223856443489329llu, + 14141172994150357llu, + 17816761525534927llu, + 22447712886978529llu, + 28282345988300791llu, + 35633523051069991llu, + 44895425773957261llu, + 56564691976601587llu, + 71267046102139967llu, + 89790851547914507llu, + 113129383953203213llu, + 142534092204280003llu, + 179581703095829107llu, + 226258767906406483llu, + 285068184408560057llu, + 359163406191658253llu, + 452517535812813007llu, + 570136368817120201llu, + 718326812383316683llu, + 905035071625626043llu, + 1140272737634240411llu, + 1436653624766633509llu, + 1810070143251252131llu, + 2280545475268481167llu, + 2873307249533267101llu, + 3620140286502504283llu, + 4561090950536962147llu, + 5746614499066534157llu, + 7240280573005008577llu, + 9122181901073924329llu, + 11493228998133068689llu, + 14480561146010017169llu, + 18446744073709551557llu}; + // NOLINTNEXTLINE(*c-array*) + static constexpr uint64_t (*const mod_functions[])(uint64_t) = { + &mod0, + &mod2, + &mod3, + &mod5, + &mod7, + &mod11, + &mod13, + &mod17, + &mod23, + &mod29, + &mod37, + &mod47, + &mod59, + &mod73, + &mod97, + &mod127, + &mod151, + &mod197, + &mod251, + &mod313, + &mod397, + &mod499, + &mod631, + &mod797, + &mod1009, + &mod1259, + &mod1597, + &mod2011, + &mod2539, + &mod3203, + &mod4027, + &mod5087, + &mod6421, + &mod8089, + &mod10193, + &mod12853, + &mod16193, + &mod20399, + &mod25717, + &mod32401, + &mod40823, + &mod51437, + &mod64811, + &mod81649, + &mod102877, + &mod129607, + &mod163307, + &mod205759, + &mod259229, + &mod326617, + &mod411527, + &mod518509, + &mod653267, + &mod823117, + &mod1037059, + &mod1306601, + &mod1646237, + &mod2074129, + &mod2613229, + &mod3292489, + &mod4148279, + &mod5226491, + &mod6584983, + &mod8296553, + &mod10453007, + &mod13169977, + &mod16593127, + &mod20906033, + &mod26339969, + &mod33186281, + &mod41812097, + &mod52679969, + &mod66372617, + &mod83624237, + &mod105359939, + &mod132745199, + &mod167248483, + &mod210719881, + &mod265490441, + &mod334496971, + &mod421439783, + &mod530980861, + &mod668993977, + &mod842879579, + &mod1061961721, + &mod1337987929, + &mod1685759167, + &mod2123923447, + &mod2675975881, + &mod3371518343, + &mod4247846927, + &mod5351951779, + &mod6743036717, + &mod8495693897, + &mod10703903591, + &mod13486073473, + &mod16991387857, + &mod21407807219, + &mod26972146961, + &mod33982775741, + &mod42815614441, + &mod53944293929, + &mod67965551447, + &mod85631228929, + &mod107888587883, + &mod135931102921, + &mod171262457903, + &mod215777175787, + &mod271862205833, + &mod342524915839, + &mod431554351609, + &mod543724411781, + &mod685049831731, + &mod863108703229, + &mod1087448823553, + &mod1370099663459, + &mod1726217406467, + &mod2174897647073, + &mod2740199326961, + &mod3452434812973, + &mod4349795294267, + &mod5480398654009, + &mod6904869625999, + &mod8699590588571, + &mod10960797308051, + &mod13809739252051, + &mod17399181177241, + &mod21921594616111, + &mod27619478504183, + &mod34798362354533, + &mod43843189232363, + &mod55238957008387, + &mod69596724709081, + &mod87686378464759, + &mod110477914016779, + &mod139193449418173, + &mod175372756929481, + &mod220955828033581, + &mod278386898836457, + &mod350745513859007, + &mod441911656067171, + &mod556773797672909, + &mod701491027718027, + &mod883823312134381, + &mod1113547595345903, + &mod1402982055436147, + &mod1767646624268779, + &mod2227095190691797, + &mod2805964110872297, + &mod3535293248537579, + &mod4454190381383713, + &mod5611928221744609, + &mod7070586497075177, + &mod8908380762767489, + &mod11223856443489329, + &mod14141172994150357, + &mod17816761525534927, + &mod22447712886978529, + &mod28282345988300791, + &mod35633523051069991, + &mod44895425773957261, + &mod56564691976601587, + &mod71267046102139967, + &mod89790851547914507, + &mod113129383953203213, + &mod142534092204280003, + &mod179581703095829107, + &mod226258767906406483, + &mod285068184408560057, + &mod359163406191658253, + &mod452517535812813007, + &mod570136368817120201, + &mod718326812383316683, + &mod905035071625626043, + &mod1140272737634240411, + &mod1436653624766633509, + &mod1810070143251252131, + &mod2280545475268481167, + &mod2873307249533267101, + &mod3620140286502504283, + &mod4561090950536962147, + &mod5746614499066534157, + &mod7240280573005008577, + &mod9122181901073924329, + &mod11493228998133068689, + &mod14480561146010017169, + &mod18446744073709551557}; + const uint64_t* found = std::lower_bound( + std::begin(prime_list), std::end(prime_list) - 1, size); + size = *found; + return mod_functions[1 + found - prime_list]; + } + void commit(mod_function new_mod_function) { + current_mod_function = new_mod_function; + } + void reset() { + current_mod_function = &mod0; + } + + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) + const { + return current_mod_function(hash); + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index > num_slots_minus_one ? current_mod_function(index) : index; + } + + private: + mod_function current_mod_function = &mod0; +}; + +struct power_of_two_hash_policy { + uint64_t index_for_hash(uint64_t hash, uint64_t num_slots_minus_one) const { + return hash & num_slots_minus_one; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index_for_hash(index, num_slots_minus_one); + } + int8_t next_size_over(uint64_t& size) const { + size = detailv3::next_power_of_two(size); + return 0; + } + void commit(int8_t /*unused*/) {} + void reset() {} +}; + +struct fibonacci_hash_policy { + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) + const { + return (11400714819323198485ull * hash) >> shift; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index & num_slots_minus_one; + } + + int8_t next_size_over(uint64_t& size) const { + size = std::max(uint64_t(2), detailv3::next_power_of_two(size)); + return static_cast(64 - detailv3::log2(size)); + } + void commit(int8_t shift_) { + shift = shift_; + } + void reset() { + shift = 63; + } + + private: + int8_t shift = 63; +}; + +template < + typename K, + typename V, + typename H = std::hash, + typename E = std::equal_to, + typename A = std::allocator>> +class order_preserving_flat_hash_map + : public detailv3::sherwood_v3_table< + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>> { + using Table = detailv3::sherwood_v3_table< + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>>; + + public: + using key_type = K; + using mapped_type = V; + + using Table::Table; + order_preserving_flat_hash_map() = default; + + inline V& operator[](const K& key) { + return emplace(key, convertible_to_value()).first->second; + } + inline V& operator[](K&& key) { + return emplace(std::move(key), convertible_to_value()).first->second; + } + V& at(const K& key) { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + const V& at(const K& key) const { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + + using Table::emplace; + std::pair emplace() { + return emplace(key_type(), convertible_to_value()); + } + template + std::pair insert_or_assign( + const key_type& key, + M&& m) { + auto emplace_result = emplace(key, std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + std::pair insert_or_assign( + key_type&& key, + M&& m) { + auto emplace_result = emplace(std::move(key), std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + typename Table::iterator insert_or_assign( + typename Table::const_iterator /*unused*/, + const key_type& key, + M&& m) { + return insert_or_assign(key, std::forward(m)).first; + } + template + typename Table::iterator insert_or_assign( + typename Table::const_iterator /*unused*/, + key_type&& key, + M&& m) { + return insert_or_assign(std::move(key), std::forward(m)).first; + } + + friend bool operator==( + const order_preserving_flat_hash_map& lhs, + const order_preserving_flat_hash_map& rhs) { + if (lhs.size() != rhs.size()) + return false; + for (const typename Table::value_type& value : lhs) { + auto found = rhs.find(value.first); + if (found == rhs.end() || value.second != found->second) + return false; + } + return true; + } + friend bool operator!=( + const order_preserving_flat_hash_map& lhs, + const order_preserving_flat_hash_map& rhs) { + return !(lhs == rhs); + } + + private: + struct convertible_to_value { + operator V() const { + return V(); + } + }; +}; + +template < + typename T, + typename H = std::hash, + typename E = std::equal_to, + typename A = std::allocator> +class flat_hash_set + : public detailv3::sherwood_v3_table< + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>> { + using Table = detailv3::sherwood_v3_table< + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>; + + public: + using key_type = T; + + using Table::Table; + flat_hash_set() = default; + + template + std::pair emplace(Args&&... args) { + return Table::emplace(T(std::forward(args)...)); + } + std::pair emplace(const key_type& arg) { + return Table::emplace(arg); + } + std::pair emplace(key_type& arg) { + return Table::emplace(arg); + } + std::pair emplace(const key_type&& arg) { + return Table::emplace(std::move(arg)); + } + std::pair emplace(key_type&& arg) { + return Table::emplace(std::move(arg)); + } + + friend bool operator==(const flat_hash_set& lhs, const flat_hash_set& rhs) { + if (lhs.size() != rhs.size()) + return false; + for (const T& value : lhs) { + if (rhs.find(value) == rhs.end()) + return false; + } + return true; + } + friend bool operator!=(const flat_hash_set& lhs, const flat_hash_set& rhs) { + return !(lhs == rhs); + } +}; + +template +struct power_of_two_std_hash : std::hash { + typedef ska_ordered::power_of_two_hash_policy hash_policy; +}; + +} // namespace ska_ordered + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/overflows.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/overflows.h new file mode 100644 index 0000000000000000000000000000000000000000..e414de5aaab43b00062139b718067b14be4422ac --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/overflows.h @@ -0,0 +1,105 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace c10 { +// In some versions of MSVC, there will be a compiler error when building. +// C4146: unary minus operator applied to unsigned type, result still unsigned +// C4804: unsafe use of type 'bool' in operation +// It can be addressed by disabling the following warning. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4146) +#pragma warning(disable : 4804) +#pragma warning(disable : 4018) +#endif + +// The overflow checks may involve float to int conversion which may +// trigger precision loss warning. Re-enable the warning once the code +// is fixed. See T58053069. +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +// bool can be converted to any type. +// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build: +// `error: comparison of constant '255' with boolean expression is always false` +// for `f > limit::max()` below +template +std::enable_if_t, bool> overflows( + From /*f*/, + bool strict_unsigned [[maybe_unused]] = false) { + return false; +} + +// skip isnan and isinf check for integral types +template +std::enable_if_t && !std::is_same_v, bool> +overflows(From f, bool strict_unsigned = false) { + using limit = std::numeric_limits::type>; + if constexpr (!limit::is_signed && std::numeric_limits::is_signed) { + // allow for negative numbers to wrap using two's complement arithmetic. + // For example, with uint8, this allows for `a - b` to be treated as + // `a + 255 * b`. + if (!strict_unsigned) { + return greater_than_max(f) || + (c10::is_negative(f) && + -static_cast(f) > static_cast(limit::max())); + } + } + return c10::less_than_lowest(f) || greater_than_max(f); +} + +template +std::enable_if_t, bool> overflows( + From f, + bool strict_unsigned [[maybe_unused]] = false) { + using limit = std::numeric_limits::type>; + if (limit::has_infinity && std::isinf(static_cast(f))) { + return false; + } + if (!limit::has_quiet_NaN && (f != f)) { + return true; + } + return f < limit::lowest() || f > limit::max(); +} + +C10_CLANG_DIAGNOSTIC_POP() + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +template +std::enable_if_t::value, bool> overflows( + From f, + bool strict_unsigned = false) { + // casts from complex to real are considered to overflow if the + // imaginary component is non-zero + if (!is_complex::value && f.imag() != 0) { + return true; + } + // Check for overflow componentwise + // (Technically, the imag overflow check is guaranteed to be false + // when !is_complex, but any optimizer worth its salt will be + // able to figure it out.) + return overflows< + typename scalar_value_type::type, + typename From::value_type>(f.real(), strict_unsigned) || + overflows< + typename scalar_value_type::type, + typename From::value_type>(f.imag(), strict_unsigned); +} +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/overloaded.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/overloaded.h new file mode 100644 index 0000000000000000000000000000000000000000..9c1571b57e808ab068dd5456e1ea83dfd9fd6342 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/overloaded.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +namespace c10 { +namespace detail { + +template +struct overloaded_t {}; + +template +struct overloaded_t : T0 { + using T0::operator(); + overloaded_t(T0 t0) : T0(std::move(t0)) {} +}; +template +struct overloaded_t : T0, overloaded_t { + using T0::operator(); + using overloaded_t::operator(); + overloaded_t(T0 t0, Ts... ts) + : T0(std::move(t0)), overloaded_t(std::move(ts)...) {} +}; + +} // namespace detail + +// Construct an overloaded callable combining multiple callables, e.g. lambdas +template +detail::overloaded_t overloaded(Ts... ts) { + return {std::move(ts)...}; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/python_stub.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/python_stub.h new file mode 100644 index 0000000000000000000000000000000000000000..f457be5949a775e9ce3f4b8b39d8c4bbe95985b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/python_stub.h @@ -0,0 +1,9 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +struct _object; +using PyObject = _object; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/qint32.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/qint32.h new file mode 100644 index 0000000000000000000000000000000000000000..2b48a5a89c503e4a3ddae1aee65695044d1a3384 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/qint32.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/qint8.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/qint8.h new file mode 100644 index 0000000000000000000000000000000000000000..47f7a9e42540c917299479e9bda73da37083e082 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/qint8.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint2x4.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint2x4.h new file mode 100644 index 0000000000000000000000000000000000000000..b7781bc5772828da4ec97e1db4bbab2b7f54dd42 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint2x4.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint4x2.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint4x2.h new file mode 100644 index 0000000000000000000000000000000000000000..b4603a707c35a3a24eee27c4eea54c025f49454b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint4x2.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint8.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint8.h new file mode 100644 index 0000000000000000000000000000000000000000..5445be70945ff028d6ad98cff1732b678c7245da --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/quint8.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/safe_numerics.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/safe_numerics.h new file mode 100644 index 0000000000000000000000000000000000000000..f376f9dfd8a529851dd45a9319da482eae1c7c60 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/safe_numerics.h @@ -0,0 +1,119 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#include +#include +#include + +// GCC has __builtin_mul_overflow from before it supported __has_builtin +#ifdef _MSC_VER +#define C10_HAS_BUILTIN_OVERFLOW() (0) +#include +#include +#else +#define C10_HAS_BUILTIN_OVERFLOW() (1) +#endif + +namespace c10 { + +template , int> = 0> +C10_ALWAYS_INLINE bool add_overflows(T a, T b, T* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + return __builtin_add_overflow(a, b, out); +#else + if constexpr (std::is_signed_v) { + // For signed types, detect overflow by checking sign changes + volatile T tmp = a + b; + *out = tmp; + + // If both operands have the same sign, check if result changed sign + // unexpectedly. + if ((a > 0) == (b > 0)) { + if ((a > 0) && (tmp <= 0)) { + return true; // Positive overflow + } + if ((a < 0) && (tmp >= 0)) { + return true; // Negative overflow + } + } + return false; + } else { + // For unsigned types, overflow causes wrap-around + volatile T tmp = a + b; + *out = tmp; + return (tmp < a || tmp < b); + } +#endif +} + +C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { + return add_overflows(a, b, out); +} + +template , int> = 0> +C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + if constexpr (std::is_signed_v) { + // For signed types, use the division-based check + volatile T tmp = a * b; + *out = tmp; + if (a == 0 || b == 0) { + return false; + } + return !(a == tmp / b); + } else { + // For unsigned types, use leading zeros approach + // This test isn't exact, but avoids doing integer division + *out = a * b; + constexpr int bits = sizeof(T) * 8; + return ( + (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < + bits); + } +#endif +} + +C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { + return mul_overflows(a, b, out); +} + +template +bool safe_multiplies_u64(It first, It last, uint64_t* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + uint64_t prod = 1; + bool overflow = false; + for (; first != last; ++first) { + overflow |= c10::mul_overflows(prod, *first, &prod); + } + *out = prod; + return overflow; +#else + uint64_t prod = 1; + uint64_t prod_log2 = 0; + bool is_zero = false; + for (; first != last; ++first) { + auto x = static_cast(*first); + prod *= x; + // log2(0) isn't valid, so need to track it specially + is_zero |= (x == 0); + prod_log2 += c10::llvm::Log2_64_Ceil(x); + } + *out = prod; + // This test isn't exact, but avoids doing integer division + return !is_zero && (prod_log2 >= 64); +#endif +} + +template +bool safe_multiplies_u64(const Container& c, uint64_t* out) { + return safe_multiplies_u64(c.begin(), c.end(), out); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/signal_handler.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/signal_handler.h new file mode 100644 index 0000000000000000000000000000000000000000..60b2c344e0639fb6490a1c300cb77469f111bd62 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/signal_handler.h @@ -0,0 +1,124 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +#include + +#if defined(__APPLE__) +#define C10_SUPPORTS_SIGNAL_HANDLER +#elif defined(__linux__) && !defined(C10_DISABLE_SIGNAL_HANDLERS) +#define C10_SUPPORTS_FATAL_SIGNAL_HANDLERS +#define C10_SUPPORTS_SIGNAL_HANDLER +#endif + +#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) +#include +#endif + +namespace c10 { + +class C10_API SignalHandler { + public: + enum class Action { NONE, STOP }; + + // Constructor. Specify what action to take when a signal is received. + SignalHandler(Action SIGINT_action, Action SIGHUP_action); + + SignalHandler(const SignalHandler&) = delete; + SignalHandler(SignalHandler&&) = delete; + SignalHandler& operator=(const SignalHandler&) = delete; + SignalHandler& operator=(SignalHandler&&) = delete; + ~SignalHandler(); + + Action CheckForSignals(); + + bool GotSIGINT(); + bool GotSIGHUP(); + + Action SIGINT_action_; + Action SIGHUP_action_; + std::atomic my_sigint_count_; + std::atomic my_sighup_count_; +}; + +#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) +class C10_API FatalSignalHandler { + // This works by setting up certain fatal signal handlers. Previous fatal + // signal handlers will still be called when the signal is raised. Defaults + // to being off. + public: + C10_API void setPrintStackTracesOnFatalSignal(bool print); + C10_API bool printStackTracesOnFatalSignal(); + static FatalSignalHandler& getInstance(); + FatalSignalHandler(const FatalSignalHandler&) = delete; + FatalSignalHandler(FatalSignalHandler&&) = delete; + FatalSignalHandler& operator=(const FatalSignalHandler&) = delete; + FatalSignalHandler& operator=(FatalSignalHandler&&) = delete; + virtual ~FatalSignalHandler() = default; + + protected: + explicit FatalSignalHandler(); + + private: + void installFatalSignalHandlers(); + void uninstallFatalSignalHandlers(); + static void fatalSignalHandlerStatic(int signum); + void fatalSignalHandler(int signum); + virtual void fatalSignalHandlerPostProcess(); + struct sigaction* getPreviousSigaction(int signum); + const char* getSignalName(int signum); + void callPreviousSignalHandler( + struct sigaction* action, + int signum, + siginfo_t* info, + void* ctx); + void stacktraceSignalHandler(bool needsLock); + static void stacktraceSignalHandlerStatic( + int signum, + siginfo_t* info, + void* ctx); + void stacktraceSignalHandler(int signum, siginfo_t* info, void* ctx); + + // The mutex protects the bool. + std::mutex fatalSignalHandlersInstallationMutex; + bool fatalSignalHandlersInstalled; + // We need to hold a reference to call the previous SIGUSR2 handler in case + // we didn't signal it + struct sigaction previousSigusr2{}; + // Flag dictating whether the SIGUSR2 handler falls back to previous handlers + // or is intercepted in order to print a stack trace. + std::atomic fatalSignalReceived; + // Global state set when a fatal signal is received so that backtracing + // threads know why they're printing a stacktrace. + const char* fatalSignalName; + int fatalSignum = -1; + // This wait condition is used to wait for other threads to finish writing + // their stack trace when in fatal sig handler (we can't use pthread_join + // because there's no way to convert from a tid to a pthread_t). + std::condition_variable writingCond; + std::mutex writingMutex; + // used to indicate if the other thread responded to the signal + bool signalReceived; + + struct signal_handler { + const char* name; + int signum; + struct sigaction previous; + }; + + // NOLINTNEXTLINE(*c-arrays*) + static signal_handler kSignalHandlers[]; +}; + +#endif // defined(C10_SUPPORTS_SIGNAL_HANDLER) + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/sparse_bitset.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/sparse_bitset.h new file mode 100644 index 0000000000000000000000000000000000000000..877b4fb52f0ed04a6bb555201f0cc58163bfe552 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/sparse_bitset.h @@ -0,0 +1,898 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the SparseBitVector class. See the doxygen comment for +// SparseBitVector for more details on the algorithm used. +// +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// SparseBitVector is an implementation of a bitvector that is sparse by only +/// storing the elements that have non-zero bits set. In order to make this +/// fast for the most common cases, SparseBitVector is implemented as a linked +/// list of SparseBitVectorElements. We maintain a pointer to the last +/// SparseBitVectorElement accessed (in the form of a list iterator), in order +/// to make multiple in-order test/set constant time after the first one is +/// executed. Note that using vectors to store SparseBitVectorElement's does +/// not work out very well because it causes insertion in the middle to take +/// enormous amounts of time with a large amount of bits. Other structures that +/// have better worst cases for insertion in the middle (various balanced trees, +/// etc) do not perform as well in practice as a linked list with this iterator +/// kept up to date. They are also significantly more memory intensive. + +template +struct SparseBitVectorElement { + public: + using BitWord = unsigned long; + using size_type = unsigned; + enum { + BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT, + BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE, + BITS_PER_ELEMENT = ElementSize + }; + + private: + // Index of Element in terms of where first bit starts. + unsigned ElementIndex; + std::array Bits{}; + + SparseBitVectorElement() : ElementIndex(~0U) {} + + public: + explicit SparseBitVectorElement(unsigned Idx) : ElementIndex(Idx) {} + + // Comparison. + bool operator==(const SparseBitVectorElement& RHS) const { + if (ElementIndex != RHS.ElementIndex) + return false; + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i] != RHS.Bits[i]) + return false; + return true; + } + + bool operator!=(const SparseBitVectorElement& RHS) const { + return !(*this == RHS); + } + + // Return the bits that make up word Idx in our element. + BitWord word(unsigned Idx) const { + assert(Idx < BITWORDS_PER_ELEMENT); + return Bits[Idx]; + } + + unsigned index() const { + return ElementIndex; + } + + bool empty() const { + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i]) + return false; + return true; + } + + void set(unsigned Idx) { + Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE); + } + + bool test_and_set(unsigned Idx) { + bool old = test(Idx); + if (!old) { + set(Idx); + return true; + } + return false; + } + + void reset(unsigned Idx) { + Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE)); + } + + bool test(unsigned Idx) const { + return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE)); + } + + size_type count() const { + unsigned NumBits = 0; + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + NumBits += llvm::countPopulation(Bits[i]); + return NumBits; + } + + /// find_first - Returns the index of the first set bit. + int find_first() const { + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i] != 0) + return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]); + throw std::runtime_error("Illegal empty element"); + } + + /// find_last - Returns the index of the last set bit. + int find_last() const { + for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) { + unsigned Idx = BITWORDS_PER_ELEMENT - I - 1; + if (Bits[Idx] != 0) + return Idx * BITWORD_SIZE + BITWORD_SIZE - + llvm::countLeadingZeros(Bits[Idx]); + } + throw std::runtime_error("Illegal empty element"); + } + + /// find_next - Returns the index of the next set bit starting from the + /// "Curr" bit. Returns -1 if the next set bit is not found. + int find_next(unsigned Curr) const { + if (Curr >= BITS_PER_ELEMENT) + return -1; + + unsigned WordPos = Curr / BITWORD_SIZE; + unsigned BitPos = Curr % BITWORD_SIZE; + BitWord Copy = Bits[WordPos]; + assert( + WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element"); + + // Mask off previous bits. + Copy &= ~0UL << BitPos; + + if (Copy != 0) + return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy); + + // Check subsequent words. + for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i] != 0) + return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]); + return -1; + } + + // Union this element with RHS and return true if this one changed. + bool unionWith(const SparseBitVectorElement& RHS) { + bool changed = false; + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + BitWord old = changed ? 0 : Bits[i]; + + Bits[i] |= RHS.Bits[i]; + if (!changed && old != Bits[i]) + changed = true; + } + return changed; + } + + // Return true if we have any bits in common with RHS + bool intersects(const SparseBitVectorElement& RHS) const { + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + if (RHS.Bits[i] & Bits[i]) + return true; + } + return false; + } + + // Intersect this Element with RHS and return true if this one changed. + // BecameZero is set to true if this element became all-zero bits. + bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) { + bool changed = false; + bool allzero = true; + + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + BitWord old = changed ? 0 : Bits[i]; + + Bits[i] &= RHS.Bits[i]; + if (Bits[i] != 0) + allzero = false; + + if (!changed && old != Bits[i]) + changed = true; + } + BecameZero = allzero; + return changed; + } + + // Intersect this Element with the complement of RHS and return true if this + // one changed. BecameZero is set to true if this element became all-zero + // bits. + bool intersectWithComplement( + const SparseBitVectorElement& RHS, + bool& BecameZero) { + bool changed = false; + bool allzero = true; + + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + BitWord old = changed ? 0 : Bits[i]; + + Bits[i] &= ~RHS.Bits[i]; + if (Bits[i] != 0) + allzero = false; + + if (!changed && old != Bits[i]) + changed = true; + } + BecameZero = allzero; + return changed; + } + + // Three argument version of intersectWithComplement that intersects + // RHS1 & ~RHS2 into this element + void intersectWithComplement( + const SparseBitVectorElement& RHS1, + const SparseBitVectorElement& RHS2, + bool& BecameZero) { + bool allzero = true; + + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i]; + if (Bits[i] != 0) + allzero = false; + } + BecameZero = allzero; + } +}; + +template +class SparseBitVector { + using ElementList = std::list>; + using ElementListIter = typename ElementList::iterator; + using ElementListConstIter = typename ElementList::const_iterator; + enum { BITWORD_SIZE = SparseBitVectorElement::BITWORD_SIZE }; + + ElementList Elements; + // Pointer to our current Element. This has no visible effect on the external + // state of a SparseBitVector, it's just used to improve performance in the + // common case of testing/modifying bits with similar indices. + mutable ElementListIter CurrElementIter; + + // This is like std::lower_bound, except we do linear searching from the + // current position. + ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const { + // We cache a non-const iterator so we're forced to resort to const_cast to + // get the begin/end in the case where 'this' is const. To avoid duplication + // of code with the only difference being whether the const cast is present + // 'this' is always const in this particular function and we sort out the + // difference in FindLowerBound and FindLowerBoundConst. + ElementListIter Begin = + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast*>(this)->Elements.begin(); + ElementListIter End = + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast*>(this)->Elements.end(); + + if (Elements.empty()) { + CurrElementIter = Begin; + return CurrElementIter; + } + + // Make sure our current iterator is valid. + if (CurrElementIter == End) + --CurrElementIter; + + // Search from our current iterator, either backwards or forwards, + // depending on what element we are looking for. + ElementListIter ElementIter = CurrElementIter; + if (CurrElementIter->index() == ElementIndex) { + return ElementIter; + } else if (CurrElementIter->index() > ElementIndex) { + while (ElementIter != Begin && ElementIter->index() > ElementIndex) + --ElementIter; + } else { + while (ElementIter != End && ElementIter->index() < ElementIndex) + ++ElementIter; + } + CurrElementIter = ElementIter; + return ElementIter; + } + ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const { + return FindLowerBoundImpl(ElementIndex); + } + ElementListIter FindLowerBound(unsigned ElementIndex) { + return FindLowerBoundImpl(ElementIndex); + } + + // Iterator to walk set bits in the bitmap. This iterator is a lot uglier + // than it would be, in order to be efficient. + class SparseBitVectorIterator { + private: + bool AtEnd{false}; + + const SparseBitVector* BitVector = nullptr; + + // Current element inside of bitmap. + ElementListConstIter Iter; + + // Current bit number inside of our bitmap. + unsigned BitNumber{0}; + + // Current word number inside of our element. + unsigned WordNumber{0}; + + // Current bits from the element. + typename SparseBitVectorElement::BitWord Bits{0}; + + // Move our iterator to the first non-zero bit in the bitmap. + void AdvanceToFirstNonZero() { + if (AtEnd) + return; + if (BitVector->Elements.empty()) { + AtEnd = true; + return; + } + Iter = BitVector->Elements.begin(); + BitNumber = Iter->index() * ElementSize; + unsigned BitPos = Iter->find_first(); + BitNumber += BitPos; + WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE; + Bits = Iter->word(WordNumber); + Bits >>= BitPos % BITWORD_SIZE; + } + + // Move our iterator to the next non-zero bit. + void AdvanceToNextNonZero() { + if (AtEnd) + return; + + while (Bits && !(Bits & 1)) { + Bits >>= 1; + BitNumber += 1; + } + + // See if we ran out of Bits in this word. + if (!Bits) { + int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize); + // If we ran out of set bits in this element, move to next element. + if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) { + ++Iter; + WordNumber = 0; + + // We may run out of elements in the bitmap. + if (Iter == BitVector->Elements.end()) { + AtEnd = true; + return; + } + // Set up for next non-zero word in bitmap. + BitNumber = Iter->index() * ElementSize; + NextSetBitNumber = Iter->find_first(); + BitNumber += NextSetBitNumber; + WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE; + Bits = Iter->word(WordNumber); + Bits >>= NextSetBitNumber % BITWORD_SIZE; + } else { + WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE; + Bits = Iter->word(WordNumber); + Bits >>= NextSetBitNumber % BITWORD_SIZE; + BitNumber = Iter->index() * ElementSize; + BitNumber += NextSetBitNumber; + } + } + } + + public: + SparseBitVectorIterator() = default; + + SparseBitVectorIterator( + const SparseBitVector* RHS, + bool end = false) + : AtEnd(end), + BitVector(RHS), + Iter(BitVector->Elements.begin()), + WordNumber(~0) { + AdvanceToFirstNonZero(); + } + + // Preincrement. + inline SparseBitVectorIterator& operator++() { + ++BitNumber; + Bits >>= 1; + AdvanceToNextNonZero(); + return *this; + } + + // Postincrement. + inline SparseBitVectorIterator operator++(int) { + SparseBitVectorIterator tmp = *this; + ++*this; + return tmp; + } + + // Return the current set bit number. + unsigned operator*() const { + return BitNumber; + } + + bool operator==(const SparseBitVectorIterator& RHS) const { + // If they are both at the end, ignore the rest of the fields. + if (AtEnd && RHS.AtEnd) + return true; + // Otherwise they are the same if they have the same bit number and + // bitmap. + return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber; + } + + bool operator!=(const SparseBitVectorIterator& RHS) const { + return !(*this == RHS); + } + }; + + public: + using iterator = SparseBitVectorIterator; + + SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {} + + SparseBitVector(const SparseBitVector& RHS) + : Elements(RHS.Elements), CurrElementIter(Elements.begin()) {} + SparseBitVector(SparseBitVector&& RHS) noexcept + : Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {} + ~SparseBitVector() = default; + + // Clear. + void clear() { + Elements.clear(); + } + + // Assignment + SparseBitVector& operator=(const SparseBitVector& RHS) { + if (this == &RHS) + return *this; + + Elements = RHS.Elements; + CurrElementIter = Elements.begin(); + return *this; + } + SparseBitVector& operator=(SparseBitVector&& RHS) noexcept { + Elements = std::move(RHS.Elements); + CurrElementIter = Elements.begin(); + return *this; + } + + // Test, Reset, and Set a bit in the bitmap. + bool test(unsigned Idx) const { + if (Elements.empty()) + return false; + + unsigned ElementIndex = Idx / ElementSize; + ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex); + + // If we can't find an element that is supposed to contain this bit, there + // is nothing more to do. + if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex) + return false; + return ElementIter->test(Idx % ElementSize); + } + + void reset(unsigned Idx) { + if (Elements.empty()) + return; + + unsigned ElementIndex = Idx / ElementSize; + ElementListIter ElementIter = FindLowerBound(ElementIndex); + + // If we can't find an element that is supposed to contain this bit, there + // is nothing more to do. + if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex) + return; + ElementIter->reset(Idx % ElementSize); + + // When the element is zeroed out, delete it. + if (ElementIter->empty()) { + ++CurrElementIter; + Elements.erase(ElementIter); + } + } + + void set(unsigned Idx) { + unsigned ElementIndex = Idx / ElementSize; + ElementListIter ElementIter; + if (Elements.empty()) { + ElementIter = Elements.emplace(Elements.end(), ElementIndex); + } else { + ElementIter = FindLowerBound(ElementIndex); + + if (ElementIter == Elements.end() || + ElementIter->index() != ElementIndex) { + // We may have hit the beginning of our SparseBitVector, in which case, + // we may need to insert right after this element, which requires moving + // the current iterator forward one, because insert does insert before. + if (ElementIter != Elements.end() && + ElementIter->index() < ElementIndex) + ++ElementIter; + ElementIter = Elements.emplace(ElementIter, ElementIndex); + } + } + CurrElementIter = ElementIter; + + ElementIter->set(Idx % ElementSize); + } + + bool test_and_set(unsigned Idx) { + bool old = test(Idx); + if (!old) { + set(Idx); + return true; + } + return false; + } + + bool operator!=(const SparseBitVector& RHS) const { + return !(*this == RHS); + } + + bool operator==(const SparseBitVector& RHS) const { + ElementListConstIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end(); + ++Iter1, ++Iter2) { + if (*Iter1 != *Iter2) + return false; + } + return Iter1 == Elements.end() && Iter2 == RHS.Elements.end(); + } + + // Union our bitmap with the RHS and return true if we changed. + bool operator|=(const SparseBitVector& RHS) { + if (this == &RHS) + return false; + + if (empty()) { + *this = RHS; + return true; + } + + bool changed = false; + ElementListIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // If RHS is empty, we are done + if (RHS.Elements.empty()) + return false; + + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) { + Elements.insert(Iter1, *Iter2); + ++Iter2; + changed = true; + } else if (Iter1->index() == Iter2->index()) { + changed |= Iter1->unionWith(*Iter2); + ++Iter1; + ++Iter2; + } else { + ++Iter1; + } + } + CurrElementIter = Elements.begin(); + return changed; + } + + // Intersect our bitmap with the RHS and return true if ours changed. + bool operator-=(const SparseBitVector& RHS) { + return intersectWithComplement(RHS); + } + + // Intersect our bitmap with the RHS and return true if ours changed. + bool operator&=(const SparseBitVector& RHS) { + if (this == &RHS) + return false; + + bool changed = false; + ElementListIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // Check if both bitmaps are empty. + if (Elements.empty() && RHS.Elements.empty()) + return false; + + // Loop through, intersecting as we go, erasing elements when necessary. + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end()) { + CurrElementIter = Elements.begin(); + return changed; + } + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + bool BecameZero = false; + changed |= Iter1->intersectWith(*Iter2, BecameZero); + if (BecameZero) { + ElementListIter IterTmp = Iter1; + ++Iter1; + Elements.erase(IterTmp); + } else { + ++Iter1; + } + ++Iter2; + } else { + ElementListIter IterTmp = Iter1; + ++Iter1; + Elements.erase(IterTmp); + changed = true; + } + } + if (Iter1 != Elements.end()) { + Elements.erase(Iter1, Elements.end()); + changed = true; + } + CurrElementIter = Elements.begin(); + return changed; + } + + // Intersect our bitmap with the complement of the RHS and return true + // if ours changed. + bool intersectWithComplement(const SparseBitVector& RHS) { + if (this == &RHS) { + if (!empty()) { + clear(); + return true; + } + return false; + } + + bool changed = false; + ElementListIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // If either our bitmap or RHS is empty, we are done + if (Elements.empty() || RHS.Elements.empty()) + return false; + + // Loop through, intersecting as we go, erasing elements when necessary. + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end()) { + CurrElementIter = Elements.begin(); + return changed; + } + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + bool BecameZero = false; + changed |= Iter1->intersectWithComplement(*Iter2, BecameZero); + if (BecameZero) { + ElementListIter IterTmp = Iter1; + ++Iter1; + Elements.erase(IterTmp); + } else { + ++Iter1; + } + ++Iter2; + } else { + ++Iter1; + } + } + CurrElementIter = Elements.begin(); + return changed; + } + + bool intersectWithComplement(const SparseBitVector* RHS) const { + return intersectWithComplement(*RHS); + } + + // Three argument version of intersectWithComplement. + // Result of RHS1 & ~RHS2 is stored into this bitmap. + void intersectWithComplement( + const SparseBitVector& RHS1, + const SparseBitVector& RHS2) { + if (this == &RHS1) { + intersectWithComplement(RHS2); + return; + } else if (this == &RHS2) { + SparseBitVector RHS2Copy(RHS2); + intersectWithComplement(RHS1, RHS2Copy); + return; + } + + Elements.clear(); + CurrElementIter = Elements.begin(); + ElementListConstIter Iter1 = RHS1.Elements.begin(); + ElementListConstIter Iter2 = RHS2.Elements.begin(); + + // If RHS1 is empty, we are done + // If RHS2 is empty, we still have to copy RHS1 + if (RHS1.Elements.empty()) + return; + + // Loop through, intersecting as we go, erasing elements when necessary. + while (Iter2 != RHS2.Elements.end()) { + if (Iter1 == RHS1.Elements.end()) + return; + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + bool BecameZero = false; + Elements.emplace_back(Iter1->index()); + Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero); + if (BecameZero) + Elements.pop_back(); + ++Iter1; + ++Iter2; + } else { + Elements.push_back(*Iter1++); + } + } + + // copy the remaining elements + std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements)); + } + + void intersectWithComplement( + const SparseBitVector* RHS1, + const SparseBitVector* RHS2) { + intersectWithComplement(*RHS1, *RHS2); + } + + bool intersects(const SparseBitVector* RHS) const { + return intersects(*RHS); + } + + // Return true if we share any bits in common with RHS + bool intersects(const SparseBitVector& RHS) const { + ElementListConstIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // Check if both bitmaps are empty. + if (Elements.empty() && RHS.Elements.empty()) + return false; + + // Loop through, intersecting stopping when we hit bits in common. + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end()) + return false; + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + if (Iter1->intersects(*Iter2)) + return true; + ++Iter1; + ++Iter2; + } else { + ++Iter1; + } + } + return false; + } + + // Return true iff all bits set in this SparseBitVector are + // also set in RHS. + bool contains(const SparseBitVector& RHS) const { + SparseBitVector Result(*this); + Result &= RHS; + return (Result == RHS); + } + + // Return the first set bit in the bitmap. Return -1 if no bits are set. + int find_first() const { + if (Elements.empty()) + return -1; + const SparseBitVectorElement& First = *(Elements.begin()); + return (First.index() * ElementSize) + First.find_first(); + } + + // Return the last set bit in the bitmap. Return -1 if no bits are set. + int find_last() const { + if (Elements.empty()) + return -1; + const SparseBitVectorElement& Last = *(Elements.rbegin()); + return (Last.index() * ElementSize) + Last.find_last(); + } + + // Return true if the SparseBitVector is empty + bool empty() const { + return Elements.empty(); + } + + unsigned count() const { + unsigned BitCount = 0; + for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end(); + ++Iter) + BitCount += Iter->count(); + + return BitCount; + } + + iterator begin() const { + return iterator(this); + } + + iterator end() const { + return iterator(this, true); + } +}; + +// Convenience functions to allow Or and And without dereferencing in the user +// code. + +template +inline bool operator|=( + SparseBitVector& LHS, + const SparseBitVector* RHS) { + return LHS |= *RHS; +} + +template +inline bool operator|=( + SparseBitVector* LHS, + const SparseBitVector& RHS) { + return LHS->operator|=(RHS); +} + +template +inline bool operator&=( + SparseBitVector* LHS, + const SparseBitVector& RHS) { + return LHS->operator&=(RHS); +} + +template +inline bool operator&=( + SparseBitVector& LHS, + const SparseBitVector* RHS) { + return LHS &= *RHS; +} + +// Convenience functions for infix union, intersection, difference operators. + +template +inline SparseBitVector operator|( + const SparseBitVector& LHS, + const SparseBitVector& RHS) { + SparseBitVector Result(LHS); + Result |= RHS; + return Result; +} + +template +inline SparseBitVector operator&( + const SparseBitVector& LHS, + const SparseBitVector& RHS) { + SparseBitVector Result(LHS); + Result &= RHS; + return Result; +} + +template +inline SparseBitVector operator-( + const SparseBitVector& LHS, + const SparseBitVector& RHS) { + SparseBitVector Result; + Result.intersectWithComplement(LHS, RHS); + return Result; +} + +template +std::ostream& operator<<( + std::ostream& stream, + const SparseBitVector& vec) { + bool first = true; + stream << '{'; + for (auto el : vec) { + if (first) { + first = false; + } else { + stream << ", "; + } + stream << el; + } + stream << '}'; + return stream; +} + +} // end namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ssize.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ssize.h new file mode 100644 index 0000000000000000000000000000000000000000..395bf8a2eb7c5ef35f0de9530f4f9ccb9fe18e42 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/ssize.h @@ -0,0 +1,51 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include + +namespace c10 { + +// Implementations of std::ssize() from C++ 20. +// +// This is useful in particular for avoiding -Werror=sign-compare +// issues. +// +// Use this with argument-dependent lookup, e.g.: +// use c10::ssize; +// auto size = ssize(container); +// +// As with the standard library version, containers are permitted to +// specialize this with a free function defined in the same namespace. +// +// See https://en.cppreference.com/w/cpp/iterator/size for more +// information as well as the source of our implementations. +// +// We augment the implementation by adding an assert() if an overflow +// would occur. + +template +constexpr auto ssize(const C& c) -> std:: + common_type_t> { + using R = std:: + common_type_t>; + // We expect this to be exceedingly rare to fire and don't wish to + // pay a performance hit in release mode. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!greater_than_max(c.size())); + return static_cast(c.size()); +} + +template +// NOLINTNEXTLINE(*-c-arrays) +constexpr auto ssize(const T (&array)[N]) noexcept -> std::ptrdiff_t { + return N; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/static_tracepoint.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/static_tracepoint.h new file mode 100644 index 0000000000000000000000000000000000000000..4030828469d45cdbef603bbb8588071a41b9b398 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/static_tracepoint.h @@ -0,0 +1,39 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#if defined(__ELF__) && (defined(__x86_64__) || defined(__i386__)) && \ + !(defined(TORCH_DISABLE_SDT) && TORCH_DISABLE_SDT) + +#define TORCH_HAVE_SDT 1 + +#include + +#define TORCH_SDT(name, ...) \ + TORCH_SDT_PROBE_N( \ + pytorch, name, 0, TORCH_SDT_NARG(0, ##__VA_ARGS__), ##__VA_ARGS__) +// Use TORCH_SDT_DEFINE_SEMAPHORE(name) to define the semaphore +// as global variable before using the TORCH_SDT_WITH_SEMAPHORE macro +#define TORCH_SDT_WITH_SEMAPHORE(name, ...) \ + TORCH_SDT_PROBE_N( \ + pytorch, name, 1, TORCH_SDT_NARG(0, ##__VA_ARGS__), ##__VA_ARGS__) +#define TORCH_SDT_IS_ENABLED(name) (TORCH_SDT_SEMAPHORE(pytorch, name) > 0) + +#else + +#define TORCH_HAVE_SDT 0 + +#define TORCH_SDT(name, ...) \ + do { \ + } while (0) +#define TORCH_SDT_WITH_SEMAPHORE(name, ...) \ + do { \ + } while (0) +#define TORCH_SDT_IS_ENABLED(name) (false) +#define TORCH_SDT_DEFINE_SEMAPHORE(name) +#define TORCH_SDT_DECLARE_SEMAPHORE(name) + +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/static_tracepoint_elfx86.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/static_tracepoint_elfx86.h new file mode 100644 index 0000000000000000000000000000000000000000..a3afe767fee1e9cf92062b2ece5e2f0520dcb9e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/static_tracepoint_elfx86.h @@ -0,0 +1,149 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// clang-format off + +// Default constraint for the probe arguments as operands. +#ifndef TORCH_SDT_ARG_CONSTRAINT +#define TORCH_SDT_ARG_CONSTRAINT "nor" +#endif + +// Instruction to emit for the probe. +#define TORCH_SDT_NOP nop + +// Note section properties. +#define TORCH_SDT_NOTE_NAME "stapsdt" +#define TORCH_SDT_NOTE_TYPE 3 + +// Semaphore variables are put in this section +#define TORCH_SDT_SEMAPHORE_SECTION ".probes" + +// Size of address depending on platform. +#ifdef __LP64__ +#define TORCH_SDT_ASM_ADDR .8byte +#else +#define TORCH_SDT_ASM_ADDR .4byte +#endif + +// Assembler helper Macros. +#define TORCH_SDT_S(x) #x +#define TORCH_SDT_ASM_1(x) TORCH_SDT_S(x) "\n" +#define TORCH_SDT_ASM_2(a, b) TORCH_SDT_S(a) "," TORCH_SDT_S(b) "\n" +#define TORCH_SDT_ASM_3(a, b, c) TORCH_SDT_S(a) "," TORCH_SDT_S(b) "," \ + TORCH_SDT_S(c) "\n" +#define TORCH_SDT_ASM_STRING(x) TORCH_SDT_ASM_1(.asciz TORCH_SDT_S(x)) + +// Helper to determine the size of an argument. +#define TORCH_SDT_IS_ARRAY_POINTER(x) ((__builtin_classify_type(x) == 14) || \ + (__builtin_classify_type(x) == 5)) +#define TORCH_SDT_ARGSIZE(x) (TORCH_SDT_IS_ARRAY_POINTER(x) \ + ? sizeof(void*) \ + : sizeof(x)) + +// Format of each probe arguments as operand. +// Size of the argument tagged with TORCH_SDT_Sn, with "n" constraint. +// Value of the argument tagged with TORCH_SDT_An, with configured constraint. +#define TORCH_SDT_ARG(n, x) \ + [TORCH_SDT_S##n] "n" ((size_t)TORCH_SDT_ARGSIZE(x)), \ + [TORCH_SDT_A##n] TORCH_SDT_ARG_CONSTRAINT (x) + +// Templates to append arguments as operands. +#define TORCH_SDT_OPERANDS_0() [__sdt_dummy] "g" (0) +#define TORCH_SDT_OPERANDS_1(_1) TORCH_SDT_ARG(1, _1) +#define TORCH_SDT_OPERANDS_2(_1, _2) \ + TORCH_SDT_OPERANDS_1(_1), TORCH_SDT_ARG(2, _2) +#define TORCH_SDT_OPERANDS_3(_1, _2, _3) \ + TORCH_SDT_OPERANDS_2(_1, _2), TORCH_SDT_ARG(3, _3) +#define TORCH_SDT_OPERANDS_4(_1, _2, _3, _4) \ + TORCH_SDT_OPERANDS_3(_1, _2, _3), TORCH_SDT_ARG(4, _4) +#define TORCH_SDT_OPERANDS_5(_1, _2, _3, _4, _5) \ + TORCH_SDT_OPERANDS_4(_1, _2, _3, _4), TORCH_SDT_ARG(5, _5) +#define TORCH_SDT_OPERANDS_6(_1, _2, _3, _4, _5, _6) \ + TORCH_SDT_OPERANDS_5(_1, _2, _3, _4, _5), TORCH_SDT_ARG(6, _6) +#define TORCH_SDT_OPERANDS_7(_1, _2, _3, _4, _5, _6, _7) \ + TORCH_SDT_OPERANDS_6(_1, _2, _3, _4, _5, _6), TORCH_SDT_ARG(7, _7) +#define TORCH_SDT_OPERANDS_8(_1, _2, _3, _4, _5, _6, _7, _8) \ + TORCH_SDT_OPERANDS_7(_1, _2, _3, _4, _5, _6, _7), TORCH_SDT_ARG(8, _8) +#define TORCH_SDT_OPERANDS_9(_1, _2, _3, _4, _5, _6, _7, _8, _9) \ + TORCH_SDT_OPERANDS_8(_1, _2, _3, _4, _5, _6, _7, _8), TORCH_SDT_ARG(9, _9) + +// Templates to reference the arguments from operands in note section. +#define TORCH_SDT_ARGFMT(no) %n[TORCH_SDT_S##no]@%[TORCH_SDT_A##no] +#define TORCH_SDT_ARG_TEMPLATE_0 /*No arguments*/ +#define TORCH_SDT_ARG_TEMPLATE_1 TORCH_SDT_ARGFMT(1) +#define TORCH_SDT_ARG_TEMPLATE_2 TORCH_SDT_ARG_TEMPLATE_1 TORCH_SDT_ARGFMT(2) +#define TORCH_SDT_ARG_TEMPLATE_3 TORCH_SDT_ARG_TEMPLATE_2 TORCH_SDT_ARGFMT(3) +#define TORCH_SDT_ARG_TEMPLATE_4 TORCH_SDT_ARG_TEMPLATE_3 TORCH_SDT_ARGFMT(4) +#define TORCH_SDT_ARG_TEMPLATE_5 TORCH_SDT_ARG_TEMPLATE_4 TORCH_SDT_ARGFMT(5) +#define TORCH_SDT_ARG_TEMPLATE_6 TORCH_SDT_ARG_TEMPLATE_5 TORCH_SDT_ARGFMT(6) +#define TORCH_SDT_ARG_TEMPLATE_7 TORCH_SDT_ARG_TEMPLATE_6 TORCH_SDT_ARGFMT(7) +#define TORCH_SDT_ARG_TEMPLATE_8 TORCH_SDT_ARG_TEMPLATE_7 TORCH_SDT_ARGFMT(8) +#define TORCH_SDT_ARG_TEMPLATE_9 TORCH_SDT_ARG_TEMPLATE_8 TORCH_SDT_ARGFMT(9) + +// Resolvable by name macros +// An attribute that marks a function or variable as needing to be resolvable +// by name. This generally is needed if inline assembly refers to the variable +// by string name. +#ifdef __roar__ +#define TORCH_NAME_RESOLVABLE __attribute__((roar_resolvable_by_name)) +#else +#define TORCH_NAME_RESOLVABLE +#endif + +// Semaphore define, declare and probe note format + +#define TORCH_SDT_SEMAPHORE(provider, name) \ + torch_sdt_semaphore_##provider##_##name + +#define TORCH_SDT_DEFINE_SEMAPHORE(name) \ + extern "C" { \ + TORCH_NAME_RESOLVABLE \ + volatile unsigned short TORCH_SDT_SEMAPHORE(pytorch, name) \ + __attribute__((section(TORCH_SDT_SEMAPHORE_SECTION), used)) = 0; \ + } + +#define TORCH_SDT_DECLARE_SEMAPHORE(name) \ + extern "C" TORCH_NAME_RESOLVABLE volatile unsigned short \ + TORCH_SDT_SEMAPHORE(pytorch, name) + +#define TORCH_SDT_SEMAPHORE_NOTE_0(provider, name) \ + TORCH_SDT_ASM_1( TORCH_SDT_ASM_ADDR 0) /*No Semaphore*/ \ + +#define TORCH_SDT_SEMAPHORE_NOTE_1(provider, name) \ + TORCH_SDT_ASM_1(TORCH_SDT_ASM_ADDR TORCH_SDT_SEMAPHORE(provider, name)) + +// Structure of note section for the probe. +#define TORCH_SDT_NOTE_CONTENT(provider, name, has_semaphore, arg_template) \ + TORCH_SDT_ASM_1(990: TORCH_SDT_NOP) \ + TORCH_SDT_ASM_3( .pushsection .note.stapsdt,"","note") \ + TORCH_SDT_ASM_1( .balign 4) \ + TORCH_SDT_ASM_3( .4byte 992f-991f, 994f-993f, TORCH_SDT_NOTE_TYPE) \ + TORCH_SDT_ASM_1(991: .asciz TORCH_SDT_NOTE_NAME) \ + TORCH_SDT_ASM_1(992: .balign 4) \ + TORCH_SDT_ASM_1(993: TORCH_SDT_ASM_ADDR 990b) \ + TORCH_SDT_ASM_1( TORCH_SDT_ASM_ADDR 0) /*Reserved for Base Address*/ \ + TORCH_SDT_SEMAPHORE_NOTE_##has_semaphore(provider, name) \ + TORCH_SDT_ASM_STRING(provider) \ + TORCH_SDT_ASM_STRING(name) \ + TORCH_SDT_ASM_STRING(arg_template) \ + TORCH_SDT_ASM_1(994: .balign 4) \ + TORCH_SDT_ASM_1( .popsection) + +// Main probe Macro. +#define TORCH_SDT_PROBE(provider, name, has_semaphore, n, arglist) \ + __asm__ __volatile__ ( \ + TORCH_SDT_NOTE_CONTENT( \ + provider, name, has_semaphore, TORCH_SDT_ARG_TEMPLATE_##n) \ + :: TORCH_SDT_OPERANDS_##n arglist \ + ) \ + +// Helper Macros to handle variadic arguments. +#define TORCH_SDT_NARG_(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, N, ...) N +#define TORCH_SDT_NARG(...) \ + TORCH_SDT_NARG_(__VA_ARGS__, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) +#define TORCH_SDT_PROBE_N(provider, name, has_semaphore, N, ...) \ + TORCH_SDT_PROBE(provider, name, has_semaphore, N, (__VA_ARGS__)) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/strides.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/strides.h new file mode 100644 index 0000000000000000000000000000000000000000..1e74cffc5e6338d234846bac166d5fcac7db63b0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/strides.h @@ -0,0 +1,29 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +namespace c10 { + +// Computes the contiguous strides of a tensor, given its sizes. +inline DimVector contiguous_strides(const IntArrayRef sizes) { + using Int = IntArrayRef::value_type; + const Int dims = static_cast(sizes.size()); + + // With this initialisation we get the case dim == 0 or 1 right + DimVector strides(dims, 1); + + for (auto i = dims - 2; i >= 0; --i) { + // Strides can't be 0 even if sizes are 0. + strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1}); + } + + return strides; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/string_utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/string_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..cbcf0b1f3c95d2e0e572ae58b6e066efc893f582 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/string_utils.h @@ -0,0 +1,27 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NO_DEPRECATED) + +namespace c10 { + +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::stod; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::stoi; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::stoll; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::stoull; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::to_string; + +} // namespace c10 + +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/string_view.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/string_view.h new file mode 100644 index 0000000000000000000000000000000000000000..559cde09f9c35071293f0ed62d481ea7f6940710 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/string_view.h @@ -0,0 +1,648 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace c10 { + +/** + * Port of std::string_view with methods from C++20. + * Implemented following the interface definition in + * https://en.cppreference.com/w/cpp/string/basic_string_view + * See there for the API documentation. + * + * Difference: We don't have a Traits template parameter because + * std::char_traits isn't constexpr and we'd have to reimplement + * std::char_traits if we wanted to use it with our constexpr basic_string_view. + */ +template +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class basic_string_view final { + public: + using value_type = CharT; + using pointer = CharT*; + using const_pointer = const CharT*; + using reference = CharT&; + using const_reference = const CharT&; + using const_iterator = const CharT*; + using iterator = const_iterator; + using const_reverse_iterator = std::reverse_iterator; + using reverse_iterator = const_reverse_iterator; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + static constexpr size_type npos = size_type(-1); + + constexpr basic_string_view() noexcept : begin_(nullptr) {} + + explicit constexpr basic_string_view(const_pointer str, size_type count) + : begin_(str), size_(count) {} + + /* implicit */ constexpr basic_string_view(const_pointer str) + : basic_string_view(str, strlen_(str)) {} + + /* implicit */ basic_string_view(const ::std::basic_string& str) + : basic_string_view(str.data(), str.size()) {} + + /* implicit */ constexpr basic_string_view( + const ::std::basic_string_view& str) + : basic_string_view(str.data(), str.size()) {} + + constexpr basic_string_view(const basic_string_view&) noexcept = default; + + constexpr basic_string_view& operator=( + const basic_string_view& rhs) noexcept = default; + + constexpr operator ::std::basic_string_view() const { + return ::std::basic_string_view(data(), size()); + } + + explicit operator ::std::basic_string() const { + return ::std::basic_string(data(), size()); + } + + constexpr const_iterator begin() const noexcept { + return cbegin(); + } + + constexpr const_iterator cbegin() const noexcept { + return begin_; + } + + constexpr const_iterator end() const noexcept { + return cend(); + } + + constexpr const_iterator cend() const noexcept { + return begin_ + size_; + } + + constexpr const_reverse_iterator rbegin() const noexcept { + return crbegin(); + } + + constexpr const_reverse_iterator crbegin() const noexcept { + return const_reverse_iterator(this->end()); + } + + constexpr const_reverse_iterator rend() const noexcept { + return crend(); + } + + constexpr const_reverse_iterator crend() const noexcept { + return const_reverse_iterator(this->begin()); + } + + friend constexpr const_iterator begin(basic_string_view sv) noexcept { + return sv.begin(); + } + + friend constexpr const_iterator end(basic_string_view sv) noexcept { + return sv.end(); + } + + constexpr const_reference operator[](size_type pos) const { + // TODO: split out + return at_(pos); + } + + constexpr const_reference at(size_type pos) const { +#if !defined( \ + __CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code + return C10_UNLIKELY(pos >= size_) + ? (throw std::out_of_range( + "string_view::operator[] or string_view::at() out of range. Index: " + + std::to_string(pos) + ", size: " + std::to_string(size())), + at_(0)) + : at_(pos); +#else + return at_(pos); +#endif + } + + constexpr const_reference front() const { + return *begin_; + } + + constexpr const_reference back() const { + return *(begin_ + size_ - 1); + } + + constexpr const_pointer data() const noexcept { + return begin_; + } + + constexpr size_type size() const noexcept { + return size_; + } + + constexpr size_type length() const noexcept { + return size(); + } + + constexpr size_type max_size() const noexcept { + return std::numeric_limits::max(); + } + + [[nodiscard]] constexpr bool empty() const noexcept { + return size() == 0; + } + + constexpr void remove_prefix(size_type n) { + if (n > size()) { + throw std::out_of_range( + "basic_string_view::remove_prefix: out of range. PrefixLength: " + + std::to_string(n) + ", size: " + std::to_string(size())); + } + begin_ += n; + size_ -= n; + } + + constexpr void remove_suffix(size_type n) { + if (n > size()) { + throw std::out_of_range( + "basic_string_view::remove_suffix: out of range. SuffixLength: " + + std::to_string(n) + ", size: " + std::to_string(size())); + } + size_ -= n; + } + + constexpr void swap(basic_string_view& sv) noexcept { + auto tmp = *this; + *this = sv; + sv = tmp; + } + + size_type copy(pointer dest, size_type count, size_type pos = 0) const { + if (pos > size_) { + throw std::out_of_range( + "basic_string_view::copy: out of range. Index: " + + std::to_string(pos) + ", size: " + std::to_string(size())); + } + size_type copy_length = std::min(count, size_ - pos); + for (auto iter = begin() + pos, end = iter + copy_length; iter != end;) { + *(dest++) = *(iter++); + } + return copy_length; + } + + constexpr basic_string_view substr(size_type pos = 0, size_type count = npos) + const { +#if !defined( \ + __CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code + return (pos > size_) + ? (throw std::out_of_range( + "basic_string_view::substr parameter out of bounds. Index: " + + std::to_string(pos) + ", size: " + std::to_string(size())), + substr_()) + : substr_(pos, count); +#else + return substr_(pos, count); +#endif + } + + constexpr int compare(basic_string_view rhs) const noexcept { + // Write it iteratively. This is faster. + for (size_t i = 0, end = std::min(size(), rhs.size()); i < end; ++i) { + if (at_(i) < rhs.at_(i)) { + return -1; + } else if (at_(i) > rhs.at_(i)) { + return 1; + } + } + if (size() < rhs.size()) { + return -1; + } else if (size() > rhs.size()) { + return 1; + } + return 0; + } + + constexpr int compare(size_type pos1, size_type count1, basic_string_view v) + const { + return substr(pos1, count1).compare(v); + } + + constexpr int compare( + size_type pos1, + size_type count1, + basic_string_view v, + size_type pos2, + size_type count2) const { + return substr(pos1, count1).compare(v.substr(pos2, count2)); + } + + constexpr int compare(const_pointer s) const { + return compare(basic_string_view(s)); + } + + constexpr int compare(size_type pos1, size_type count1, const_pointer s) + const { + return substr(pos1, count1).compare(basic_string_view(s)); + } + + constexpr int compare( + size_type pos1, + size_type count1, + const_pointer s, + size_type count2) const { + return substr(pos1, count1).compare(basic_string_view(s, count2)); + } + + friend constexpr bool operator==( + basic_string_view lhs, + basic_string_view rhs) noexcept { + return lhs.equals_(rhs); + } + + friend constexpr bool operator!=( + basic_string_view lhs, + basic_string_view rhs) noexcept { + return !(lhs == rhs); + } + + friend constexpr bool operator<( + basic_string_view lhs, + basic_string_view rhs) noexcept { + return lhs.compare(rhs) < 0; + } + + friend constexpr bool operator>=( + basic_string_view lhs, + basic_string_view rhs) noexcept { + return !(lhs < rhs); + } + + friend constexpr bool operator>( + basic_string_view lhs, + basic_string_view rhs) noexcept { + return rhs < lhs; + } + + friend constexpr bool operator<=( + basic_string_view lhs, + basic_string_view rhs) noexcept { + return !(lhs > rhs); + } + + constexpr bool starts_with(basic_string_view prefix) const noexcept { + return (prefix.size() > size()) ? false + : prefix.equals_(substr_(0, prefix.size())); + } + + constexpr bool starts_with(CharT prefix) const noexcept { + return !empty() && prefix == front(); + } + + constexpr bool starts_with(const_pointer prefix) const { + return starts_with(basic_string_view(prefix)); + } + + constexpr bool ends_with(basic_string_view suffix) const noexcept { + return (suffix.size() > size()) + ? false + : suffix.equals_(substr_(size() - suffix.size(), suffix.size())); + } + + constexpr bool ends_with(CharT suffix) const noexcept { + return !empty() && suffix == back(); + } + + constexpr bool ends_with(const_pointer suffix) const { + return ends_with(basic_string_view(suffix)); + } + + constexpr size_type find(basic_string_view v, size_type pos = 0) + const noexcept { + if (v.empty()) { + return pos <= size() ? pos : npos; + } + + if (pos + v.size() <= size()) { + for (size_type cur = pos, end = size() - v.size(); cur <= end; ++cur) { + if (v.at_(0) == at_(cur) && + v.substr_(1).equals_(substr_(cur + 1, v.size() - 1))) { + return cur; + } + } + } + return npos; + } + + constexpr size_type find(CharT ch, size_type pos = 0) const noexcept { + return find_first_if_(pos, charIsEqual_{ch}); + } + + constexpr size_type find(const_pointer s, size_type pos, size_type count) + const { + return find(basic_string_view(s, count), pos); + } + + constexpr size_type find(const_pointer s, size_type pos = 0) const { + return find(basic_string_view(s), pos); + } + + constexpr size_type rfind(basic_string_view v, size_type pos = npos) + const noexcept { + // Write it iteratively. This is faster. + if (v.empty()) { + return pos <= size() ? pos : size(); + } + + if (v.size() <= size()) { + pos = std::min(size() - v.size(), pos); + do { + if (v.at_(0) == at_(pos) && + v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) { + return pos; + } + } while (pos-- > 0); + } + return npos; + } + + constexpr size_type rfind(CharT ch, size_type pos = npos) const noexcept { + return find_last_if_(pos, charIsEqual_{ch}); + } + + constexpr size_type rfind(const_pointer s, size_type pos, size_type count) + const { + return rfind(basic_string_view(s, count), pos); + } + + constexpr size_type rfind(const_pointer s, size_type pos = npos) const { + return rfind(basic_string_view(s), pos); + } + + constexpr size_type find_first_of(basic_string_view v, size_type pos = 0) + const noexcept { + return find_first_if_(pos, stringViewContainsChar_{v}); + } + + constexpr size_type find_first_of(CharT ch, size_type pos = 0) + const noexcept { + return find_first_if_(pos, charIsEqual_{ch}); + } + + constexpr size_type find_first_of( + const_pointer s, + size_type pos, + size_type count) const { + return find_first_of(basic_string_view(s, count), pos); + } + + constexpr size_type find_first_of(const_pointer s, size_type pos = 0) const { + return find_first_of(basic_string_view(s), pos); + } + + constexpr size_type find_last_of(basic_string_view v, size_type pos = npos) + const noexcept { + return find_last_if_(pos, stringViewContainsChar_{v}); + } + + constexpr size_type find_last_of(CharT ch, size_type pos = npos) + const noexcept { + return find_last_if_(pos, charIsEqual_{ch}); + } + + constexpr size_type find_last_of( + const_pointer s, + size_type pos, + size_type count) const { + return find_last_of(basic_string_view(s, count), pos); + } + + constexpr size_type find_last_of(const_pointer s, size_type pos = npos) + const { + return find_last_of(basic_string_view(s), pos); + } + + constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0) + const noexcept { + return find_first_if_(pos, stringViewDoesNotContainChar_{v}); + } + + constexpr size_type find_first_not_of(CharT ch, size_type pos = 0) + const noexcept { + return find_first_if_(pos, charIsNotEqual_{ch}); + } + + constexpr size_type find_first_not_of( + const_pointer s, + size_type pos, + size_type count) const { + return find_first_not_of(basic_string_view(s, count), pos); + } + + constexpr size_type find_first_not_of(const_pointer s, size_type pos = 0) + const { + return find_first_not_of(basic_string_view(s), pos); + } + + constexpr size_type find_last_not_of( + basic_string_view v, + size_type pos = npos) const noexcept { + return find_last_if_(pos, stringViewDoesNotContainChar_{v}); + } + + constexpr size_type find_last_not_of(CharT ch, size_type pos = npos) + const noexcept { + return find_last_if_(pos, charIsNotEqual_{ch}); + } + + constexpr size_type find_last_not_of( + const_pointer s, + size_type pos, + size_type count) const { + return find_last_not_of(basic_string_view(s, count), pos); + } + + constexpr size_type find_last_not_of(const_pointer s, size_type pos = npos) + const { + return find_last_not_of(basic_string_view(s), pos); + } + + private: + static constexpr size_type strlen_(const_pointer str) noexcept { + const_pointer current = str; + while (*current != '\0') { + ++current; + } + return current - str; + } + + constexpr const_reference at_(size_type pos) const noexcept { + return *(begin_ + pos); + } + + constexpr basic_string_view substr_(size_type pos = 0, size_type count = npos) + const { + return basic_string_view{begin_ + pos, std::min(count, size() - pos)}; + } + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + constexpr size_type find_first_if_(size_type pos, Condition&& condition) + const noexcept { + if (pos + 1 <= size()) { + for (size_type cur = pos; cur < size(); ++cur) { + if (condition(at_(cur))) { + return cur; + } + } + } + return npos; + } + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + constexpr size_type find_last_if_(size_type pos, Condition&& condition) + const noexcept { + // Write it iteratively. This is faster. + if (!empty()) { + pos = std::min(size() - 1, pos); + do { + if (condition(at_(pos))) { + return pos; + } + } while (pos-- > 0); + } + return npos; + } + + constexpr bool equals_(basic_string_view rhs) const { + // We don't use string_view::compare() here but implement it manually + // because only looking at equality allows for more optimized code. +#if defined(__GNUC__) && !defined(__CUDACC__) + return size() == rhs.size() && + 0 == __builtin_memcmp(data(), rhs.data(), size()); +#else + if (size() != rhs.size()) { + return false; + } + // Yes, memcmp would be laster than this loop, but memcmp isn't constexpr + // and I didn't feel like implementing a constexpr memcmp variant. + // TODO At some point this should probably be done, including tricks + // like comparing one machine word instead of a byte per iteration. + for (typename basic_string_view::size_type pos = 0; pos < size(); + ++pos) { + if (at_(pos) != rhs.at_(pos)) { + return false; + } + } + return true; +#endif + } + + struct charIsEqual_ final { + CharT expected; + constexpr bool operator()(CharT actual) const noexcept { + return expected == actual; + } + }; + + struct charIsNotEqual_ final { + CharT expected; + constexpr bool operator()(CharT actual) const noexcept { + return expected != actual; + } + }; + + struct stringViewContainsChar_ final { + basic_string_view expected; + constexpr bool operator()(CharT ch) const noexcept { + return npos != expected.find(ch); + } + }; + + struct stringViewDoesNotContainChar_ final { + basic_string_view expected; + constexpr bool operator()(CharT ch) const noexcept { + return npos == expected.find(ch); + } + }; + + const_pointer begin_; + size_type size_{}; +}; + +template +inline std::basic_ostream& operator<<( + std::basic_ostream& stream, + basic_string_view sv) { + // The rules for operator<< are quite complex, so lets defer to the + // STL implementation. + using std_string_type = ::std::basic_string_view; + return stream << std_string_type(sv.data(), sv.size()); +} + +template +constexpr inline void swap( + basic_string_view& lhs, + basic_string_view& rhs) noexcept { + lhs.swap(rhs); +} +using string_view = std::string_view; +using c10_string_view = basic_string_view; + +// NOTE: In C++20, this function should be replaced by string_view.starts_with +constexpr bool starts_with( + const std::string_view s, + const std::string_view prefix) noexcept { + return (prefix.size() > s.size()) ? false + : prefix == s.substr(0, prefix.size()); +} + +// NOTE: In C++20, this function should be replaced by string_view.starts_with +constexpr bool starts_with( + const std::string_view s, + const char prefix) noexcept { + return !s.empty() && prefix == s.front(); +} + +// NOTE: In C++20, this function should be replaced by string_view.ends_with +constexpr bool ends_with( + const std::string_view s, + const std::string_view suffix) noexcept { + return (suffix.size() > s.size()) + ? false + : suffix == s.substr(s.size() - suffix.size(), suffix.size()); +} + +// NOTE: In C++20, this function should be replaced by string_view.ends_with +constexpr bool ends_with(const std::string_view s, const char prefix) noexcept { + return !s.empty() && prefix == s.back(); +} + +} // namespace c10 + +namespace std { +template +struct hash<::c10::basic_string_view> { + size_t operator()(::c10::basic_string_view x) const { + // The standard says that std::string_view hashing must do the same as + // std::string hashing but leaves the details of std::string hashing + // up to the implementer. So, to be conformant, we need to reuse and + // existing STL type's hash function. The std::string fallback is probably + // slow but the only way to be conformant. + + using std_string_type = ::std::basic_string_view; + return ::std::hash{}(std_string_type(x.data(), x.size())); + } +}; +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/strong_type.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/strong_type.h new file mode 100644 index 0000000000000000000000000000000000000000..4e3d1a431c19958786bba8245d56bb12854fd5e3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/strong_type.h @@ -0,0 +1,1669 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * strong_type C++14/17/20 strong typedef library + * + * Copyright (C) Björn Fahller + * + * Use, modification and distribution is subject to the + * Boost Software License, Version 1.0. (See accompanying + * file LICENSE_1_0.txt or copy at + * http://www.boost.org/LICENSE_1_0.txt) + * + * Project home: https://github.com/rollbear/strong_type + */ + +#ifndef ROLLBEAR_STRONG_TYPE_HPP_INCLUDED +#define ROLLBEAR_STRONG_TYPE_HPP_INCLUDED + +#include +#include +#include +#include +#include + +#ifndef STRONG_HAS_STD_FORMAT +#define STRONG_HAS_STD_FORMAT 0 +#endif + +#ifndef STRONG_HAS_FMT_FORMAT +#define STRONG_HAS_FMT_FORMAT 0 +#endif + +#if STRONG_HAS_STD_FORMAT +#include +#if !defined(__cpp_lib_format) || __cpp_lib_format < 201907 +#undef STRONG_HAS_STD_FORMAT +#define STRONG_HAS_STD_FORMAT 0 +#endif +#endif + +#if STRONG_HAS_FMT_FORMAT +#include +#endif + +namespace strong +{ + +namespace impl +{ + template + using WhenConstructible = std::enable_if_t>; +} + +template +using modifier = typename M::template modifier; + +struct uninitialized_t {}; +static constexpr uninitialized_t uninitialized{}; + +struct default_constructible +{ + template + class modifier + { + }; +}; + +namespace impl { + template + constexpr bool supports_default_construction(const ::strong::default_constructible::modifier* /*unused*/) + { + return true; + } +} + +template +class type : public modifier>... +{ +public: + template {}>> + explicit type(uninitialized_t /*unused*/) + noexcept + { + } + template (nullptr))> + constexpr + type() + noexcept(noexcept(T{})) + : val{} + { + } + + template >> + constexpr + explicit + type( + std::initializer_list us + ) + noexcept(noexcept(T{us})) + : val{us} + { + } + template && (sizeof...(U) > 0)>> + constexpr + explicit + type( + U&& ... u) + noexcept(std::is_nothrow_constructible_v) + : val(std::forward(u)...) + {} + + friend constexpr void swap(type& a, type& b) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_move_assignable_v + ) + { + using std::swap; + swap(a.val, b.val); + } + + [[nodiscard]] + constexpr T& value_of() & noexcept { return val;} + [[nodiscard]] + constexpr const T& value_of() const & noexcept { return val;} + [[nodiscard]] + constexpr T&& value_of() && noexcept { return std::move(val);} + + [[nodiscard]] + friend constexpr T& value_of(type& t) noexcept { return t.val;} + [[nodiscard]] + friend constexpr const T& value_of(const type& t) noexcept { return t.val;} + [[nodiscard]] + friend constexpr T&& value_of(type&& t) noexcept { return std::move(t).val;} +private: + T val; +}; + +namespace impl { + template + constexpr bool is_strong_type_func(const strong::type* /*unused*/) { return true;} + constexpr bool is_strong_type_func(...) { return false;} + template + constexpr T underlying_type(strong::type*); + +} + +template +struct is_strong_type : std::integral_constant(nullptr))> {}; + +namespace impl { + template + using WhenStrongType = std::enable_if_t>::value>; + template + using WhenNotStrongType = std::enable_if_t>::value>; +} + +template ::value> +struct underlying_type +{ + using type = decltype(impl::underlying_type(static_cast(nullptr))); +}; + +template +struct underlying_type +{ + using type = T; +}; + +template +using underlying_type_t = typename underlying_type::type; + + +namespace impl { + template< + typename T, + typename = impl::WhenNotStrongType> + constexpr + T && + access(T &&t) + noexcept { + return std::forward(t); + } + template < + typename T, + typename = impl::WhenStrongType> + [[nodiscard]] + constexpr + auto + access(T&& t) + noexcept + -> decltype(value_of(std::forward(t))) + { + return value_of(std::forward(t)); + } + +} +struct equality +{ + template + class modifier; +}; + + +template +class equality::modifier<::strong::type> +{ + using type = ::strong::type; +public: + [[nodiscard]] + friend + constexpr + auto + operator==( + const type& lh, + const type& rh) + noexcept(noexcept(std::declval() == std::declval())) + -> decltype(std::declval() == std::declval()) + { + return value_of(lh) == value_of(rh); + } + + [[nodiscard]] + friend + constexpr + auto + operator!=( + const type& lh, + const type& rh) + noexcept(noexcept(std::declval() != std::declval())) + -> decltype(std::declval() != std::declval()) + { + return value_of(lh) != value_of(rh); + } +}; + +namespace impl +{ + template + class typed_equality + { + private: + using TT = underlying_type_t; + using OT = underlying_type_t; + public: + [[nodiscard]] + friend + constexpr + auto operator==(const T& lh, const Other& rh) + noexcept(noexcept(std::declval() == std::declval())) + -> decltype(std::declval() == std::declval()) + { + return value_of(lh) == impl::access(rh); + } + [[nodiscard]] + friend + constexpr + auto operator==(const Other& lh, const T& rh) + noexcept(noexcept(std::declval() == std::declval())) + -> decltype(std::declval() == std::declval()) + { + return impl::access(lh) == value_of(rh) ; + } + [[nodiscard]] + friend + constexpr + auto operator!=(const T& lh, const Other rh) + noexcept(noexcept(std::declval() != std::declval())) + -> decltype(std::declval() != std::declval()) + { + return value_of(lh) != impl::access(rh); + } + [[nodiscard]] + friend + constexpr + auto operator!=(const Other& lh, const T& rh) + noexcept(noexcept(std::declval() != std::declval())) + -> decltype(std::declval() != std::declval()) + { + return impl::access(lh) != value_of(rh) ; + } + }; +} +template +struct equality_with +{ + template + class modifier : public impl::typed_equality... + { + }; +}; + +namespace impl +{ + template + class typed_ordering + { + private: + using TT = underlying_type_t; + using OT = underlying_type_t; + public: + [[nodiscard]] + friend + constexpr + auto operator<(const T& lh, const Other& rh) + noexcept(noexcept(std::declval() < std::declval())) + -> decltype(std::declval() < std::declval()) + { + return value_of(lh) < impl::access(rh); + } + [[nodiscard]] + friend + constexpr + auto operator<(const Other& lh, const T& rh) + noexcept(noexcept(std::declval() < std::declval())) + -> decltype(std::declval() < std::declval()) + { + return impl::access(lh) < value_of(rh) ; + } + + [[nodiscard]] + friend + constexpr + auto operator<=(const T& lh, const Other& rh) + noexcept(noexcept(std::declval() <= std::declval())) + -> decltype(std::declval() <= std::declval()) + { + return value_of(lh) <= impl::access(rh); + } + [[nodiscard]] + friend + constexpr + auto operator<=(const Other& lh, const T& rh) + noexcept(noexcept(std::declval() <= std::declval())) + -> decltype(std::declval() <= std::declval()) + { + return impl::access(lh) <= value_of(rh) ; + } + + [[nodiscard]] + friend + constexpr + auto operator>(const T& lh, const Other& rh) + noexcept(noexcept(std::declval() > std::declval())) + -> decltype(std::declval() > std::declval()) + { + return value_of(lh) > impl::access(rh); + } + [[nodiscard]] + friend + constexpr + auto operator>(const Other& lh, const T& rh) + noexcept(noexcept(std::declval() > std::declval())) + -> decltype(std::declval() > std::declval()) + { + return impl::access(lh) > value_of(rh) ; + } + + [[nodiscard]] + friend + constexpr + auto operator>=(const T& lh, const Other& rh) + noexcept(noexcept(std::declval() >= std::declval())) + -> decltype(std::declval() >= std::declval()) + { + return value_of(lh) >= impl::access(rh); + } + [[nodiscard]] + friend + constexpr + auto operator>=(const Other& lh, const T& rh) + noexcept(noexcept(std::declval() >= std::declval())) + -> decltype(std::declval() >= std::declval()) + { + return impl::access(lh) >= value_of(rh) ; + } + }; +} + +template +struct ordered_with +{ + template + class modifier : public impl::typed_ordering... + { + }; +}; + +namespace impl +{ + template + struct require_copy_constructible + { + static constexpr bool value = std::is_copy_constructible>::value; + static_assert(value, "underlying type must be copy constructible"); + }; + template + struct require_move_constructible + { + static constexpr bool value = std::is_move_constructible>::value; + static_assert(value, "underlying type must be move constructible"); + }; + template + struct require_copy_assignable + { + static constexpr bool value = std::is_copy_assignable>::value; + static_assert(value, "underlying type must be copy assignable"); + }; + template + struct require_move_assignable + { + static constexpr bool value = std::is_move_assignable>::value; + static_assert(value, "underlying type must be move assignable"); + }; + + template struct valid_type; + template <> + struct valid_type {}; + + template + struct require_semiregular + : valid_type::value && + require_move_constructible::value && + require_copy_assignable::value && + require_move_assignable::value> + { + }; + +} +struct semiregular +{ + template + class modifier; +}; + +template +class semiregular::modifier<::strong::type> + : public default_constructible::modifier + , private impl::require_semiregular +{ +}; + +struct regular +{ + template + class modifier + : public semiregular::modifier + , public equality::modifier + { + }; +}; + +struct unique +{ + template + class modifier + : private impl::valid_type< + impl::require_move_constructible::value && + impl::require_move_assignable::value + > + { + public: + constexpr modifier() = default; + modifier(const modifier&) = delete; + constexpr modifier(modifier&&) = default; + modifier& operator=(const modifier&) = delete; + constexpr modifier& operator=(modifier&&) = default; + }; +}; +struct ordered +{ + template + class modifier; +}; + + +template +class ordered::modifier<::strong::type> +{ + using type = ::strong::type; +public: + [[nodiscard]] + friend + constexpr + auto + operator<( + const type& lh, + const type& rh) + noexcept(noexcept(std::declval() < std::declval())) + -> decltype(std::declval() < std::declval()) + { + return value_of(lh) < value_of(rh); + } + + [[nodiscard]] + friend + constexpr + auto + operator<=( + const type& lh, + const type& rh) + noexcept(noexcept(std::declval() <= std::declval())) + -> decltype(std::declval() <= std::declval()) + { + return value_of(lh) <= value_of(rh); + } + + [[nodiscard]] + friend + constexpr + auto + operator>( + const type& lh, + const type& rh) + noexcept(noexcept(std::declval() > std::declval())) + -> decltype(std::declval() > std::declval()) + { + return value_of(lh) > value_of(rh); + } + + [[nodiscard]] + friend + constexpr + + auto + operator>=( + const type& lh, + const type& rh) + noexcept(noexcept(std::declval() >= std::declval())) + -> decltype(std::declval() >= std::declval()) + { + return value_of(lh) >= value_of(rh); + } +}; + +struct ostreamable +{ + template + class modifier + { + public: + friend + std::ostream& + operator<<( + std::ostream &os, + const T &t) + { + return os << value_of(t); + } + }; +}; + +struct istreamable +{ + template + class modifier + { + public: + friend + std::istream& + operator>>( + std::istream &is, + T &t) + { + return is >> value_of(t); + } + }; +}; + +struct iostreamable +{ + template + class modifier + : public ostreamable::modifier + , public istreamable::modifier + { + }; +}; + +struct incrementable +{ + template + class modifier + { + public: + friend + constexpr + T& + operator++(T& t) + noexcept(noexcept(++std::declval().value_of())) + { + ++value_of(t); + return t; + } + + friend + constexpr + T + operator++(T& t, int) + { + auto copy = t; + ++t; + return copy; + } + }; +}; + +struct decrementable +{ + template + class modifier + { + public: + friend + constexpr + T& + operator--(T& t) + noexcept(noexcept(--std::declval().value_of())) + { + --value_of(t); + return t; + } + + friend + constexpr + T + operator--(T& t, int) + { + auto copy = t; + --t; + return copy; + } + }; +}; + +struct bicrementable +{ + template + class modifier + : public incrementable::modifier + , public decrementable::modifier + { + }; +}; + +struct boolean +{ + template + class modifier + { + public: + explicit constexpr operator bool() const + noexcept(noexcept(static_cast(value_of(std::declval())))) + { + const auto& self = static_cast(*this); + return static_cast(value_of(self)); + } + }; +}; + +struct hashable +{ + template + class modifier{}; +}; + +struct difference +{ + template + class modifier; +}; + +template +class difference::modifier<::strong::type> +: public ordered::modifier<::strong::type> +, public equality::modifier<::strong::type> +{ + using type = ::strong::type; +public: + friend + constexpr + type& operator+=(type& lh, const type& rh) + noexcept(noexcept(value_of(lh) += value_of(rh))) + { + value_of(lh) += value_of(rh); + return lh; + } + + friend + constexpr + type& operator-=(type& lh, const type& rh) + noexcept(noexcept(value_of(lh) -= value_of(rh))) + { + value_of(lh) -= value_of(rh); + return lh; + } + + friend + constexpr + type& operator*=(type& lh, const T& rh) + noexcept(noexcept(value_of(lh) *= rh)) + { + value_of(lh) *= rh; + return lh; + } + + friend + constexpr + type& operator/=(type& lh, const T& rh) + noexcept(noexcept(value_of(lh) /= rh)) + { + value_of(lh) /= rh; + return lh; + } + + template ()%= std::declval())> + friend + constexpr + type& operator%=(type& lh, const T& rh) + noexcept(noexcept(value_of(lh) %= rh)) + { + value_of(lh)%= rh; + return lh; + } + + friend + constexpr + type operator+(type lh, const type& rh) + { + lh += rh; + return lh; + } + + friend + constexpr + type operator-(type lh, const type& rh) + { + lh -= rh; + return lh; + } + + friend + constexpr + type operator*(type lh, const T& rh) + { + lh *= rh; + return lh; + } + + friend + constexpr + type operator*(const T& lh, type rh) + { + rh *= lh; + return rh; + } + + friend + constexpr + type operator/(type lh, const T& rh) + { + lh /= rh; + return lh; + } + + friend + constexpr + T operator/(const type& lh, const type& rh) + { + return value_of(lh) / value_of(rh); + } + + template () %= std::declval())> + friend + constexpr + type operator%(type lh, const T& rh) + noexcept(noexcept(lh%= rh)) + { + lh %= rh; + return lh; + } + + template () % std::declval())> + friend + constexpr + T operator%(type lh, type rh) + noexcept(noexcept(value_of(lh) % value_of(rh))) + { + return value_of(lh) % value_of(rh); + } +}; + +template +struct affine_point +{ + template + class modifier; +}; + +namespace impl +{ + template + using void_t = void; + + template + struct subtractable : std::false_type {}; + + template + struct subtractable() - std::declval())>> + : std::true_type {}; +} + + +template +template +class affine_point::modifier<::strong::type> +{ + using type = ::strong::type; + static_assert(impl::subtractable::value, "it must be possible to subtract instances of your underlying type"); + using base_diff_type = decltype(std::declval() - std::declval()); +public: + using difference = std::conditional_t{}, strong::type, D>; + static_assert(std::is_constructible_v, ""); + [[nodiscard]] + friend + constexpr + difference + operator-( + const type& lh, + const type& rh) + { + return difference(value_of(lh) - value_of(rh)); + } + + friend + constexpr + type& + operator+=( + type& lh, + const difference& d) + noexcept(noexcept(value_of(lh) += impl::access(d))) + { + value_of(lh) += impl::access(d); + return lh; + } + + friend + constexpr + type& + operator-=( + type& lh, + const difference& d) + noexcept(noexcept(value_of(lh) -= impl::access(d))) + { + value_of(lh) -= impl::access(d); + return lh; + } + + [[nodiscard]] + friend + constexpr + type + operator+( + type lh, + const difference& d) + { + return lh += d; + } + + [[nodiscard]] + friend + constexpr + type + operator+( + const difference& d, + type rh) + { + return rh+= d; + } + + [[nodiscard]] + friend + constexpr + type + operator-( + type lh, + const difference& d) + { + return lh -= d; + } +}; + + +struct pointer +{ + template + class modifier; +}; + +template +class pointer::modifier<::strong::type> +{ + using type = strong::type; +public: + template + [[nodiscard]] + friend + constexpr + auto + operator==( + const type& t, + std::nullptr_t) + noexcept(noexcept(std::declval() == nullptr)) + -> decltype(std::declval() == nullptr) + { + return value_of(t) == nullptr; + } + + template + [[nodiscard]] + friend + constexpr + auto + operator==( + std::nullptr_t, + const type& t) + noexcept(noexcept(nullptr == std::declval())) + -> decltype(nullptr == std::declval()) + { + return value_of(t) == nullptr; + } + + template + [[nodiscard]] + friend + constexpr + auto + operator!=( + const type& t, + std::nullptr_t) + noexcept(noexcept(std::declval() != nullptr)) + -> decltype(std::declval() != nullptr) + { + return value_of(t) != nullptr; + } + + template + [[nodiscard]] + friend + constexpr + auto + operator!=( + std::nullptr_t, + const type& t) + noexcept(noexcept(nullptr != std::declval())) + -> decltype(nullptr != std::declval()) + { + return value_of(t) != nullptr; + } + + [[nodiscard]] + constexpr + decltype(*std::declval()) + operator*() + const + { + auto& self = static_cast(*this); + return *value_of(self); + } + + [[nodiscard]] + constexpr + decltype(&(*std::declval())) operator->() const { return &operator*();} +}; + +struct arithmetic +{ + template + class modifier + { + public: + [[nodiscard]] + friend + constexpr + T + operator-( + const T &lh) + { + return T{-value_of(lh)}; + } + + friend + constexpr + T& + operator+=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) += value_of(rh))) + { + value_of(lh) += value_of(rh); + return lh; + } + + friend + constexpr + T& + operator-=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) -= value_of(rh))) + { + value_of(lh) -= value_of(rh); + return lh; + } + + friend + constexpr + T& + operator*=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) *= value_of(rh))) + { + value_of(lh) *= value_of(rh); + return lh; + } + + friend + constexpr + T& + operator/=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) /= value_of(rh))) + { + value_of(lh) /= value_of(rh); + return lh; + } + + template ()) % value_of(std::declval()))> + friend + constexpr + T& + operator%=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) %= value_of(rh))) + { + value_of(lh) %= value_of(rh); + return lh; + } + + [[nodiscard]] + friend + constexpr + T + operator+( + T lh, + const T &rh) + { + lh += rh; + return lh; + } + + [[nodiscard]] + friend + constexpr + T + operator-( + T lh, + const T &rh) + { + lh -= rh; + return lh; + } + + [[nodiscard]] + friend + constexpr + T + operator*( + T lh, + const T &rh) + { + lh *= rh; + return lh; + } + + [[nodiscard]] + friend + constexpr + T + operator/( + T lh, + const T &rh) + { + lh /= rh; + return lh; + } + + template ()) % value_of(std::declval()))> + [[nodiscard]] + friend + constexpr + T + operator%( + T lh, + const T &rh) + { + lh %= rh; + return lh; + } + + }; +}; + + +struct bitarithmetic +{ + template + class modifier + { + public: + friend + constexpr + T& + operator&=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) &= value_of(rh))) + { + value_of(lh) &= value_of(rh); + return lh; + } + + friend + constexpr + T& + operator|=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) |= value_of(rh))) + { + value_of(lh) |= value_of(rh); + return lh; + } + + friend + constexpr + T& + operator^=( + T &lh, + const T &rh) + noexcept(noexcept(value_of(lh) ^= value_of(rh))) + { + value_of(lh) ^= value_of(rh); + return lh; + } + + template + friend + constexpr + T& + operator<<=( + T &lh, + C c) + noexcept(noexcept(value_of(lh) <<= c)) + { + value_of(lh) <<= c; + return lh; + } + + template + friend + constexpr + T& + operator>>=( + T &lh, + C c) + noexcept(noexcept(value_of(lh) >>= c)) + { + value_of(lh) >>= c; + return lh; + } + + [[nodiscard]] + friend + constexpr + T + operator~( + const T &lh) + { + auto v = value_of(lh); + v = ~v; + return T(v); + } + + [[nodiscard]] + friend + constexpr + T + operator&( + T lh, + const T &rh) + { + lh &= rh; + return lh; + } + + [[nodiscard]] + friend + constexpr + T + operator|( + T lh, + const T &rh) + { + lh |= rh; + return lh; + } + + [[nodiscard]] + friend + constexpr + T + operator^( + T lh, + const T &rh) + { + lh ^= rh; + return lh; + } + + template + [[nodiscard]] + friend + constexpr + T + operator<<( + T lh, + C c) + { + lh <<= c; + return lh; + } + + template + [[nodiscard]] + friend + constexpr + T + operator>>( + T lh, + C c) + { + lh >>= c; + return lh; + } + }; +}; +template +struct indexed +{ + template + class modifier; +}; + +template <> +struct indexed { + template + class modifier; + + template + class modifier> { + using ref = T&; + using cref = const T&; + using rref = T&&; + using type = strong::type; + public: + template + [[nodiscard]] + auto + operator[]( + const I &i) + const & + noexcept(noexcept(std::declval()[impl::access(i)])) + -> decltype(std::declval()[impl::access(i)]) { + auto& self = static_cast(*this); + return value_of(self)[impl::access(i)]; + } + + template + [[nodiscard]] + auto + operator[]( + const I &i) + & + noexcept(noexcept(std::declval()[impl::access(i)])) + -> decltype(std::declval()[impl::access(i)]) { + auto& self = static_cast(*this); + return value_of(self)[impl::access(i)]; + } + + template + [[nodiscard]] + auto + operator[]( + const I &i) + && + noexcept(noexcept(std::declval()[impl::access(i)])) + -> decltype(std::declval()[impl::access(i)]) { + auto& self = static_cast(*this); + return value_of(std::move(self))[impl::access(i)]; + } + + template + [[nodiscard]] + auto + at( + const I &i) + const & + -> decltype(std::declval().at(impl::access(i))) { + auto& self = static_cast(*this); + return value_of(self).at(impl::access(i)); + } + + template + [[nodiscard]] + auto + at( + const I &i) + & + -> decltype(std::declval().at(impl::access(i))) { + auto& self = static_cast(*this); + return value_of(self).at(impl::access(i)); + } + + template + [[nodiscard]] + auto + at( + const I &i) + && + -> decltype(std::declval().at(impl::access(i))) { + auto& self = static_cast(*this); + return value_of(std::move(self)).at(impl::access(i)); + } + }; +}; + +template +template +class indexed::modifier> +{ + using type = ::strong::type; +public: + [[nodiscard]] + auto + operator[]( + const I& i) + const & + noexcept(noexcept(std::declval()[impl::access(i)])) + -> decltype(std::declval()[impl::access(i)]) + { + auto& self = static_cast(*this); + return value_of(self)[impl::access(i)]; + } + + [[nodiscard]] + auto + operator[]( + const I& i) + & + noexcept(noexcept(std::declval()[impl::access(i)])) + -> decltype(std::declval()[impl::access(i)]) + { + auto& self = static_cast(*this); + return value_of(self)[impl::access(i)]; + } + + [[nodiscard]] + auto + operator[]( + const I& i) + && + noexcept(noexcept(std::declval()[impl::access(i)])) + -> decltype(std::declval()[impl::access(i)]) + { + auto& self = static_cast(*this); + return value_of(std::move(self))[impl::access(i)]; + } + + template + [[nodiscard]] + auto + at( + const I& i) + const & + -> decltype(std::declval().at(impl::access(i))) + { + auto& self = static_cast(*this); + return value_of(self).at(impl::access(i)); + } + + template + [[nodiscard]] + auto + at( + const I& i) + & + -> decltype(std::declval().at(impl::access(i))) + { + auto& self = static_cast(*this); + return value_of(self).at(impl::access(i)); + } + + template + [[nodiscard]] + auto + at( + const I& i) + && + -> decltype(std::declval().at(impl::access(i))) + { + auto& self = static_cast(*this); + return value_of(std::move(self)).at(impl::access(i)); + } +}; + +class iterator +{ +public: + template >::iterator_category> + class modifier + : public pointer::modifier + , public equality::modifier + , public incrementable::modifier + { + public: + using difference_type = typename std::iterator_traits>::difference_type; + using value_type = typename std::iterator_traits>::value_type; + using pointer = typename std::iterator_traits>::value_type; + using reference = typename std::iterator_traits>::reference; + using iterator_category = typename std::iterator_traits>::iterator_category; + }; + + template + class modifier + : public modifier + , public decrementable::modifier + { + }; + template + class modifier + : public modifier + , public affine_point>::difference_type>::template modifier + , public indexed<>::modifier + , public ordered::modifier + { + }; +}; + +class range +{ +public: + template + class modifier; +}; + +template +class range::modifier> +{ + using type = ::strong::type; + using r_iterator = decltype(std::declval().begin()); + using r_const_iterator = decltype(std::declval().begin()); +public: + using iterator = ::strong::type; + using const_iterator = ::strong::type; + + iterator + begin() + noexcept(noexcept(std::declval().begin())) + { + auto& self = static_cast(*this); + return iterator{value_of(self).begin()}; + } + + iterator + end() + noexcept(noexcept(std::declval().end())) + { + auto& self = static_cast(*this); + return iterator{value_of(self).end()}; + } + + const_iterator + cbegin() + const + noexcept(noexcept(std::declval().begin())) + { + auto& self = static_cast(*this); + return const_iterator{value_of(self).begin()}; + } + + const_iterator + cend() + const + noexcept(noexcept(std::declval().end())) + { + auto& self = static_cast(*this); + return const_iterator{value_of(self).end()}; + } + + const_iterator + begin() + const + noexcept(noexcept(std::declval().begin())) + { + auto& self = static_cast(*this); + return const_iterator{value_of(self).begin()}; + } + + const_iterator + end() + const + noexcept(noexcept(std::declval().end())) + { + auto& self = static_cast(*this); + return const_iterator{value_of(self).end()}; + } +}; + +namespace impl { + + template + struct converter + { + constexpr explicit operator D() const + noexcept(noexcept(static_cast(std::declval&>()))) + { + auto& self = static_cast(*this); + return static_cast(value_of(self)); + } + }; + template + struct implicit_converter + { + constexpr operator D() const + noexcept(noexcept(static_cast(std::declval&>()))) + { + auto& self = static_cast(*this); + return static_cast(value_of(self)); + } + }; +} +template +struct convertible_to +{ + template + struct modifier : impl::converter... + { + }; +}; + +template +struct implicitly_convertible_to +{ + template + struct modifier : impl::implicit_converter... + { + }; + +}; + +struct formattable +{ + template + class modifier{}; +}; + +} + +namespace std { +template +struct hash<::strong::type> + : std::conditional_t< + std::is_base_of< + ::strong::hashable::modifier< + ::strong::type + >, + ::strong::type + >::value, + hash, + std::false_type> +{ + using type = ::strong::type; + decltype(auto) + operator()( + const ::strong::hashable::modifier& t) + const + noexcept(noexcept(std::declval>()(value_of(std::declval())))) + { + auto& tt = static_cast(t); + return hash::operator()(value_of(tt)); + } +}; + +#if STRONG_HAS_STD_FORMAT +template +struct formatter<::strong::type, Char, + std::enable_if_t< + std::is_base_of< + ::strong::formattable::modifier< + ::strong::type + >, + ::strong::type + >::value + >> + : formatter +{ + using type = ::strong::type; + template + constexpr + decltype(auto) + format(const ::strong::formattable::modifier& t, FormatContext& fc) + noexcept(noexcept(std::declval>().format(value_of(std::declval()), fc))) + { + const auto& tt = static_cast(t); + return formatter::format(value_of(tt), fc); + } +}; +#endif + +} + +#if STRONG_HAS_FMT_FORMAT +namespace fmt +{ +template +struct formatter<::strong::type, Char, + std::enable_if_t< + std::is_base_of< + ::strong::formattable::modifier< + ::strong::type + >, + ::strong::type + >::value + >> + : formatter +{ + using type = ::strong::type; + template + constexpr + decltype(auto) + format(const ::strong::formattable::modifier& t, FormatContext& fc) + noexcept(noexcept(std::declval>().format(value_of(std::declval()), fc))) + { + const auto& tt = static_cast(t); + return formatter::format(value_of(tt), fc); + } +}; +} +#endif +#endif //ROLLBEAR_STRONG_TYPE_HPP_INCLUDED + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/tempfile.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/tempfile.h new file mode 100644 index 0000000000000000000000000000000000000000..afcf4504c87a49112a6f4f21c6fe08f153c49a2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/tempfile.h @@ -0,0 +1,94 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { +struct C10_API TempFile { + TempFile(std::string_view name, int fd = -1) noexcept : fd(fd), name(name) {} + TempFile(const TempFile&) = delete; + TempFile(TempFile&& other) noexcept + : fd(other.fd), name(std::move(other.name)) { + other.fd = -1; + } + + TempFile& operator=(const TempFile&) = delete; + TempFile& operator=(TempFile&& other) noexcept { + fd = other.fd; + name = std::move(other.name); + other.fd = -1; + return *this; + } +#if defined(_WIN32) + bool open(); +#endif + + ~TempFile(); + + int fd; + + std::string name; +}; + +struct C10_API TempDir { + TempDir() = delete; + explicit TempDir(std::string_view name) noexcept : name(name) {} + TempDir(const TempDir&) = delete; + TempDir(TempDir&& other) noexcept : name(std::move(other.name)) { + other.name.clear(); + } + + TempDir& operator=(const TempDir&) = delete; + TempDir& operator=(TempDir&& other) noexcept { + name = std::move(other.name); + return *this; + } + + ~TempDir(); + + std::string name; +}; + +/// Attempts to return a temporary file or returns `nullopt` if an error +/// occurred. +/// +/// The file returned follows the pattern +/// `/`, where `` is the value of +/// the `"TMPDIR"`, `"TMP"`, `"TEMP"` or +/// `"TEMPDIR"` environment variable if any is set, or otherwise `/tmp`; +/// `` is the value supplied to this function, and +/// `` is a random sequence of numbers. +/// On Windows, `name_prefix` is ignored and `tmpnam_s` is used, +/// and no temporary file is opened. +C10_API std::optional try_make_tempfile( + std::string_view name_prefix = "torch-file-"); + +/// Like `try_make_tempfile`, but throws an exception if a temporary file could +/// not be returned. +C10_API TempFile make_tempfile(std::string_view name_prefix = "torch-file-"); + +/// Attempts to return a temporary directory or returns `nullopt` if an error +/// occurred. +/// +/// The directory returned follows the pattern +/// `//`, where `` is the value +/// of the `"TMPDIR"`, `"TMP"`, `"TEMP"` or +/// `"TEMPDIR"` environment variable if any is set, or otherwise `/tmp`; +/// `` is the value supplied to this function, and +/// `` is a random sequence of numbers. +/// On Windows, `name_prefix` is ignored. +C10_API std::optional try_make_tempdir( + std::string_view name_prefix = "torch-dir-"); + +/// Like `try_make_tempdir`, but throws an exception if a temporary directory +/// could not be returned. +C10_API TempDir make_tempdir(std::string_view name_prefix = "torch-dir-"); +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/thread_name.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/thread_name.h new file mode 100644 index 0000000000000000000000000000000000000000..5cda361bc8f17f673fb6735b76261b82d821f26d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/thread_name.h @@ -0,0 +1,18 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include + +namespace c10 { + +C10_API void setThreadName(std::string name); + +C10_API std::string getThreadName(); + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/typeid.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/typeid.h new file mode 100644 index 0000000000000000000000000000000000000000..3f7da4264ad5339af2535aeb10863ef07c75515a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/typeid.h @@ -0,0 +1,720 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +/* + * TypeIdentifier is a small type containing an id. + * Types must be registered using CAFFE_DECLARE_KNOWN_TYPE() (in their header) + * and CAFFE_DEFINE_KNOWN_TYPE() (in their .cpp file) for them to have a type + * id. If a type is registered, you can also create an object containing meta + * data like constructor, destructor, stringified name, ... about the type by + * calling TypeMeta::Make. This returns a TypeMeta() object, which is + * basically just a pointer to the type information, so it's cheap to pass + * around. + */ + +// TODO: This file is still in the caffe2 namespace, despite living +// in the ATen directory. This is because the macro +// CAFFE_KNOWN_TYPE (and CAFFE_DECLARE_KNOWN_TYPE) defines a template +// specialization, which relies +// on the namespace of TypeMeta matching the namespace where the macro is +// called. This requires us to fix all of the call-sites, which I want to do +// later. So the namespace is not fixed at the moment. + +// Make at::Half a fundamental type. + +namespace c10::guts { +template <> +struct is_fundamental : std::true_type {}; +} // namespace c10::guts + +namespace caffe2 { + +/** + * A type id is a unique id for a given C++ type. + * You need to register your types using CAFFE_KNOWN_TYPE(MyType) to be able to + * use TypeIdentifier with custom types. This is for example used to store the + * dtype of tensors. + */ +class C10_API TypeIdentifier final + : public at::IdWrapper { + public: + friend std::ostream& operator<<(std::ostream& stream, TypeIdentifier typeId); + friend constexpr bool operator<(TypeIdentifier lhs, TypeIdentifier rhs); + + /** + * Returns the unique id for the given type T. The id is unique for the type T + * in the sense that for any two different types, their ids are different; for + * the same type T, the id remains the same over different calls of the + * function. However, this is not guaranteed over different runs, as the id + * is generated during run-time. Do NOT serialize the id for storage. + */ + template + static constexpr TypeIdentifier Get() noexcept { + return TypeIdentifier(c10::util::get_type_index()); + } + + static constexpr TypeIdentifier uninitialized() { + return TypeIdentifier(c10::util::type_index{0}); + } + + private: + constexpr explicit TypeIdentifier(c10::util::type_index id) : IdWrapper(id) {} +}; + +// Allow usage in std::map / std::set +// TODO Disallow this and rather use std::unordered_map/set everywhere +inline constexpr bool operator<(TypeIdentifier lhs, TypeIdentifier rhs) { + return lhs.underlyingId() < rhs.underlyingId(); +} + +inline std::ostream& operator<<( + std::ostream& stream, + caffe2::TypeIdentifier typeId) { + return stream << typeId.underlyingId(); +} + +} // namespace caffe2 + +namespace at { +using DataType = caffe2::TypeIdentifier; +} + +C10_DEFINE_HASH_FOR_IDWRAPPER(caffe2::TypeIdentifier) + +namespace caffe2 { + +namespace detail { + +// This struct holds the actual type information. There will be +// one allocated per type. TypeMeta objects will then point to the struct +// instance for the type they're configured for. +struct TypeMetaData final { + using New = void*(); + using PlacementNew = void(void*, size_t); + using Copy = void(const void*, void*, size_t); + using PlacementDelete = void(void*, size_t); + using Delete = void(void*); + + constexpr TypeMetaData() noexcept + : itemsize_(0), + new_(nullptr), + placementNew_(nullptr), + copy_(nullptr), + placementDelete_(nullptr), + delete_(nullptr), + id_(TypeIdentifier::uninitialized()), + name_("nullptr (uninitialized)") {} + + constexpr TypeMetaData( + size_t itemsize, + New* newFn, + PlacementNew* placementNew, + Copy* copy, + PlacementDelete* placementDelete, + Delete* deleteFn, + TypeIdentifier id, + std::string_view name) noexcept + : itemsize_(itemsize), + new_(newFn), + placementNew_(placementNew), + copy_(copy), + placementDelete_(placementDelete), + delete_(deleteFn), + id_(id), + name_(name) {} + + size_t itemsize_; + New* new_; + PlacementNew* placementNew_; + Copy* copy_; + PlacementDelete* placementDelete_; + Delete* delete_; + TypeIdentifier id_; + std::string_view name_; +}; + +// Mechanism for throwing errors which can't be prevented at compile time +// due to type erasure. E.g. somebody calling TypeMeta::copy() for +// non-copyable type. Right now just throws exception but is implemented +// in .cpp to manage dependencies +[[noreturn]] C10_API void _ThrowRuntimeTypeLogicError(const std::string& msg); + +/** + * Placement new function for the type. + */ +template +inline void _PlacementNew(void* ptr, size_t n) { + T* typed_ptr = static_cast(ptr); + for (const auto i : c10::irange(n)) { + new (typed_ptr + i) T; + } +} + +template +inline void _PlacementNewNotDefault(void* /*ptr*/, size_t /*n*/) { + _ThrowRuntimeTypeLogicError( + "Type " + std::string(c10::util::get_fully_qualified_type_name()) + + " is not default-constructible."); +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() { + return (c10::guts::is_fundamental::value || std::is_pointer_v) + ? nullptr + : &_PlacementNew; +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() { + static_assert( + !c10::guts::is_fundamental::value && !std::is_pointer_v, + "this should have picked the other SFINAE case"); + return &_PlacementNewNotDefault; +} + +template +inline void* _New() { + return new T; +} + +template +inline void* _NewNotDefault() { + _ThrowRuntimeTypeLogicError( + "Type " + std::string(c10::util::get_fully_qualified_type_name()) + + " is not default-constructible."); +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::New* _PickNew() { + return &_New; +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::New* _PickNew() { + return &_NewNotDefault; +} + +/** + * Typed copy function for classes. + */ +template +inline void _Copy(const void* src, void* dst, size_t n) { + const T* typed_src = static_cast(src); + T* typed_dst = static_cast(dst); + for (const auto i : c10::irange(n)) { + typed_dst[i] = typed_src[i]; + } +} + +/** + * A placeholder function for types that do not allow assignment. + */ +template +inline void _CopyNotAllowed(const void* /*src*/, void* /*dst*/, size_t /*n*/) { + _ThrowRuntimeTypeLogicError( + "Type " + std::string(c10::util::get_fully_qualified_type_name()) + + " does not allow assignment."); +} + +template >* = nullptr> +inline constexpr TypeMetaData::Copy* _PickCopy() { + return (c10::guts::is_fundamental::value || std::is_pointer_v) + ? nullptr + : &_Copy; +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::Copy* _PickCopy() { + static_assert( + !c10::guts::is_fundamental::value && !std::is_pointer_v, + "this should have picked the other SFINAE case"); + return &_CopyNotAllowed; +} + +/** + * Destructor for non-fundamental types. + */ +template +inline void _PlacementDelete(void* ptr, size_t n) { + T* typed_ptr = static_cast(ptr); + for (const auto i : c10::irange(n)) { + typed_ptr[i].~T(); + } +} + +template +inline constexpr TypeMetaData::PlacementDelete* _PickPlacementDelete() { + return (c10::guts::is_fundamental::value || std::is_pointer_v) + ? nullptr + : &_PlacementDelete; +} + +template +inline void _Delete(void* ptr) { + T* typed_ptr = static_cast(ptr); + delete typed_ptr; +} + +template +inline constexpr TypeMetaData::Delete* _PickDelete() noexcept { + return &_Delete; +} + +class _Uninitialized final {}; + +} // namespace detail + +// +// note: this is outside TypeMeta bc gcc seems to have trouble +// with scalarTypeItemSizes as a constexpr static member used by +// a public inline instance method +// + +// item sizes for TypeMeta::itemsize() fast path +static constexpr std::array scalarTypeItemSizes = { +#define SCALAR_TYPE_SIZE(T, name) sizeof(T), + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_SIZE) +#undef SCALAR_TYPE_SIZE + 0, // Undefined +}; + +/** + * TypeMeta is a thin class that allows us to store the type of a container such + * as a blob, or the data type of a tensor, with a unique run-time id. It also + * stores some additional data such as the item size and the name of the type + * for run-time inspection. + */ +class C10_API TypeMeta final { + public: + using New = detail::TypeMetaData::New; + using PlacementNew = detail::TypeMetaData::PlacementNew; + using Copy = detail::TypeMetaData::Copy; + using PlacementDelete = detail::TypeMetaData::PlacementDelete; + using Delete = detail::TypeMetaData::Delete; + + /** Create a dummy TypeMeta object. To create a TypeMeta object for a specific + * type, use TypeMeta::Make(). + */ + TypeMeta() noexcept; + ~TypeMeta() = default; + + /** + * Copy constructor. + */ + TypeMeta(const TypeMeta& src) noexcept = default; + + /** + * Assignment operators. + */ + TypeMeta& operator=(const TypeMeta& src) noexcept = default; + + TypeMeta& operator=(TypeMeta&& src) noexcept = default; + TypeMeta(TypeMeta&& rhs) noexcept = default; + + inline TypeMeta& operator=(ScalarType scalar_type) noexcept { + index_ = static_cast(scalar_type); + return *this; + } + + private: + // TypeMeta can only be created by Make, making sure that we do not + // create incorrectly mixed up TypeMeta objects. + explicit TypeMeta(const uint16_t index) noexcept : index_(index) {} + + public: + /** + * Returns the type id. + */ + TypeIdentifier id() const noexcept { + return data().id_; + } + /** + * true if we represent some ScalarType type + */ + inline bool isScalarType() const noexcept { + return index_ < NumScalarTypes; + } + /** + * true if we represent ScalarType scalar_type + */ + inline bool isScalarType(ScalarType scalar_type) const noexcept { + return index_ == static_cast(scalar_type); + } + /** + * Returns the size of the item. + */ + inline size_t itemsize() const noexcept { + if (C10_LIKELY(isScalarType())) { + return scalarTypeItemSizes[index_]; + } + return data().itemsize_; + } + /** + * Returns the new function pointer for individual items. + */ + New* newFn() const noexcept { + return data().new_; + } + /** + * Returns the placement new function pointer for individual items. + */ + PlacementNew* placementNew() const noexcept { + return data().placementNew_; + } + /** + * Returns the typed copy function pointer for individual items. + */ + Copy* copy() const noexcept { + return data().copy_; + } + /** + * Returns the destructor function pointer for individual items. + */ + PlacementDelete* placementDelete() const noexcept { + return data().placementDelete_; + } + Delete* deleteFn() const noexcept { + return data().delete_; + } + /** + * Returns a printable name for the type. + */ + std::string_view name() const noexcept { + return data().name_; + } + + friend bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept; + + template + bool Match() const noexcept { + return (*this == Make()); + } + + // Below are static functions that can be called by passing a specific type. + + template + static constexpr TypeIdentifier Id() noexcept { + return TypeIdentifier::Get(); + } + + template + static std::string_view TypeName() noexcept { + return c10::util::get_fully_qualified_type_name(); + } + + template + static constexpr size_t ItemSize() noexcept { + return sizeof(T); + } + + /** + * Returns a TypeMeta object that corresponds to the typename T. + */ + template + static TypeMeta Make() { + // The instance pointed to is declared here, but defined in a .cpp file. + // We need to silence the compiler warning about using an undefined + // variable template. '-Wpragmas' and '-Wunknown-warning-option' has to be + // disabled for compilers that don't know '-Wundefined-var-template' and + // would error at our attempt to disable it. +#ifndef _MSC_VER +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpragmas" +#pragma GCC diagnostic ignored "-Wunknown-warning-option" +#pragma GCC diagnostic ignored "-Wundefined-var-template" +#endif + return TypeMeta(_typeMetaData()); +#ifndef _MSC_VER +#pragma GCC diagnostic pop +#endif + } + + /** + * convert ScalarType enum values to TypeMeta handles + */ + static inline caffe2::TypeMeta fromScalarType(ScalarType scalar_type) { + const auto index = static_cast(scalar_type); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index < NumScalarTypes, + "Unrecognized Scalartype ", + scalar_type, + " (please report this error)"); + return TypeMeta(index); + } + + /** + * convert TypeMeta handles to ScalarType enum values + */ + inline ScalarType toScalarType() const { + if (C10_LIKELY(isScalarType())) { + return static_cast(index_); + } + error_unsupported_typemeta(*this); + } + + private: + [[noreturn]] static void error_unsupported_typemeta(caffe2::TypeMeta dtype); + + // hard limit number of registered types + // note: constexpr provokes Windows compilation error "member may not be + // initialized" static constexpr size_t MaxTypeIndex = 32; + // +#if defined C10_MOBILE +// The reason for this not to be UINT8_MAX is that the array +// initialization takes space which is proportional to the size of the array. +// The compiler seems to add code (or data padding) to initialize the array with +// empty elements. Please see +// https://github.com/pytorch/pytorch/pull/51881 for details. +// +#define MaxTypeIndex \ + (NumScalarTypes + 15 /* number of CAFFE_DEFINE_KNOWN_TYPE in typeid.cpp */ + \ + 1 /* 1 more for caffe2 tensor */) +#else +#define MaxTypeIndex UINT8_MAX +#endif + + // Protects type metadata allocation. + // NOLINTNEXTLINE(facebook-hte-NonPodStaticDeclaration) + static std::mutex& getTypeMetaDatasLock(); + static uint16_t nextTypeIndex; + + static detail::TypeMetaData* typeMetaDatas(); + + static uint16_t existingMetaDataIndexForType(TypeIdentifier identifier); + + public: +#ifdef __CUDACC__ + // NOTE [ TypeIdentifier::Get nvcc/clang discrepancy] + // nvcc and clang do not produce identical results for + // TypeIdentifier::Get, because TypeIdentifier::Get relies on + // __PRETTY_FUNCTION__ and they don't agree on the canonical names + // of types (e.g., nvcc normalizes to `short unsigned int`, but clang + // calls it `unsigned short`). Hide the implementation of this function + // from nvcc so that we always use clang (or whatever host C++ compiler) + // for TypeIdentifier::Get. + template + C10_EXPORT static uint16_t addTypeMetaData(); +#else + template + C10_EXPORT static uint16_t addTypeMetaData() { + const auto identifier = TypeIdentifier::Get(); + // Need to hold this for the rest of the function, protecting: + // 1) existingMetaDataIndexForType() + // 2) nextTypeIndex++ + // 3) the write into typeMetaDatas() + std::lock_guard lock(getTypeMetaDatasLock()); + // It may exist already if added in a different dynamic shared library. + const uint16_t existing_index = existingMetaDataIndexForType(identifier); + if (existing_index != MaxTypeIndex) { + return existing_index; + } + const uint16_t index = nextTypeIndex++; + TORCH_CHECK( + index <= MaxTypeIndex, + "Maximum number of CAFFE_KNOWN_TYPE declarations has been exceeded. ", + "Please report this issue."); + typeMetaDatas()[index] = detail::TypeMetaData{ + sizeof(T), + detail::_PickNew(), + detail::_PickPlacementNew(), + detail::_PickCopy(), + detail::_PickPlacementDelete(), + detail::_PickDelete(), + identifier, + c10::util::get_fully_qualified_type_name()}; + return index; + } +#endif + + private: + // specializations return indexes into typeMetaDataInstances() + template + C10_API static uint16_t _typeMetaData() noexcept; + + // + // TypeMeta just wraps this index + // + + uint16_t index_; + + inline const detail::TypeMetaData& data() const { + return typeMetaDatas()[index_]; + } +}; + +// specializations of TypeMeta::_typeMetaData for ScalarType types + +#define DEFINE_SCALAR_METADATA_INSTANCE(T, name) \ + template <> \ + constexpr uint16_t TypeMeta::_typeMetaData() noexcept { \ + return static_cast(ScalarType::name); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_METADATA_INSTANCE) +#undef DEFINE_SCALAR_METADATA_INSTANCE + +template <> +C10_EXPORT constexpr uint16_t TypeMeta::_typeMetaData< + detail::_Uninitialized>() noexcept { + return static_cast(ScalarType::Undefined); +} + +inline TypeMeta::TypeMeta() noexcept + : index_(_typeMetaData()) {} + +inline bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept { + return (lhs.index_ == rhs.index_); +} +inline bool operator!=(const TypeMeta& lhs, const TypeMeta& rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<( + std::ostream& stream, + caffe2::TypeMeta typeMeta) { + return stream << typeMeta.name(); +} + +/** + * Register unique id for a type so it can be used in TypeMeta context, e.g. be + * used as a type for Blob or for Tensor elements. + * + * CAFFE_KNOWN_TYPE is deprecated; prefer CAFFE_DECLARE_KNOWN_TYPE and + * CAFFE_DEFINE_KNOWN_TYPE. + * + * CAFFE_KNOWN_TYPE does explicit instantiation of TypeIdentifier::Get + * template function and thus needs to be put in a single translation unit (.cpp + * file) for a given type T. Other translation units that use type T as a type + * of the caffe2::Blob or element type of caffe2::Tensor need to depend on the + * translation unit that contains CAFFE_KNOWN_TYPE declaration via regular + * linkage dependencies. + * + * NOTE: the macro needs to be invoked in ::caffe2 namespace + */ +// Implementation note: in MSVC, we will need to prepend the C10_API +// keyword in order to get things compiled properly. in Linux, gcc seems to +// create attribute ignored error for explicit template instantiations, see +// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0537r0.html +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51930 +// and as a result, we define these two macros slightly differently. +#if defined(_MSC_VER) || defined(__clang__) +#define EXPORT_IF_NOT_GCC C10_EXPORT +#else +#define EXPORT_IF_NOT_GCC +#endif + +// CAFFE_KNOWN_TYPE is deprecated! Use CAFFE_DECLARE_KNOWN_TYPE and +// CAFFE_DEFINE_KNOWN_TYPE instead. +#define CAFFE_KNOWN_TYPE(T) \ + template uint16_t TypeMeta::addTypeMetaData(); \ + template <> \ + EXPORT_IF_NOT_GCC uint16_t TypeMeta::_typeMetaData() noexcept { \ + static const uint16_t index = addTypeMetaData(); \ + return index; \ + } + +#define CAFFE_DEFINE_KNOWN_TYPE(T, ident) \ + template uint16_t TypeMeta::addTypeMetaData(); \ + namespace detail { \ + EXPORT_IF_NOT_GCC const uint16_t ident##_metadata_index = \ + TypeMeta::addTypeMetaData(); \ + } // namespace detail + +// Unlike CAFFE_KNOWN_TYPE, CAFFE_DECLARE_KNOWN_TYPE avoids a function +// call to access _typeMetaData in the common case. +#define CAFFE_DECLARE_KNOWN_TYPE(T, ident) \ + extern template uint16_t TypeMeta::addTypeMetaData(); \ + namespace detail { \ + extern C10_API const uint16_t ident##_metadata_index; \ + } /* namespace detail */ \ + template <> \ + EXPORT_IF_NOT_GCC C10_ALWAYS_INLINE uint16_t \ + TypeMeta::_typeMetaData() noexcept { \ + return detail::ident##_metadata_index; \ + } + +#define CAFFE_KNOWN_TYPE_NOEXPORT(T) \ + template <> \ + uint16_t TypeMeta::_typeMetaData() noexcept { \ + static const uint16_t index = addTypeMetaData(); \ + return index; \ + } + +CAFFE_DECLARE_KNOWN_TYPE(std::string, std_string) +CAFFE_DECLARE_KNOWN_TYPE(char, char) +CAFFE_DECLARE_KNOWN_TYPE(std::unique_ptr, std_unique_ptr_std_mutex) +CAFFE_DECLARE_KNOWN_TYPE( + std::unique_ptr>, + std_unique_ptr_std_atomic_bool) +CAFFE_DECLARE_KNOWN_TYPE(std::vector, std_vector_int32_t) +CAFFE_DECLARE_KNOWN_TYPE(std::vector, std_vector_int64_t) +CAFFE_DECLARE_KNOWN_TYPE(std::vector, std_vector_unsigned_long) +CAFFE_DECLARE_KNOWN_TYPE(bool*, bool_ptr) +CAFFE_DECLARE_KNOWN_TYPE(char*, char_ptr) +CAFFE_DECLARE_KNOWN_TYPE(int*, int_ptr) + +// For some of the compilers, long is defined separately from int32_t and +// int64_t. As a result we will need to actually define them separately. +// It is recommended that one does NOT use long - use int32_t and int64_t +// explicitly. Explicit long type annotation may go away in the future. +// details: This hack works by defining a _guard_long_unique type, which is +// long iff the compiler has a separate long type and is a dummy type otherwise. +// we then allocate a type id to that _guard_long_unique. If the compiler has a +// separate long type, this allocates a type id for long. Otherwise, it +// allocates a type id for the dummy type, which doesn't matter. +namespace detail { +template +class _guard_long_unique_dummy final {}; +template +using _guard_long_unique = std::conditional_t< + std::is_same_v || std::is_same_v, + _guard_long_unique_dummy, + T>; +} // namespace detail + +CAFFE_DECLARE_KNOWN_TYPE( + detail::_guard_long_unique, + detail_guard_long_unique_long) +CAFFE_DECLARE_KNOWN_TYPE( + detail::_guard_long_unique>, + detail_guard_long_unique_std_vector_long) + +CAFFE_DECLARE_KNOWN_TYPE(float*, float_ptr) +CAFFE_DECLARE_KNOWN_TYPE(at::Half*, at_Half) + +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/win32-headers.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/win32-headers.h new file mode 100644 index 0000000000000000000000000000000000000000..f9eb55948a858c8551a52c81ece8eab8c862c324 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/util/win32-headers.h @@ -0,0 +1,65 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#ifndef NOKERNEL +#define NOKERNEL +#endif +#ifndef NOUSER +#define NOUSER +#endif +#ifndef NOSERVICE +#define NOSERVICE +#endif +#ifndef NOSOUND +#define NOSOUND +#endif +#ifndef NOMCX +#define NOMCX +#endif +#ifndef NOGDI +#define NOGDI +#endif +#ifndef NOMSG +#define NOMSG +#endif +#ifndef NOMB +#define NOMB +#endif +#ifndef NOCLIPBOARD +#define NOCLIPBOARD +#endif + +// dbghelp seems to require windows.h. +// clang-format off +#include +#include +// clang-format on + +#undef VOID +#undef DELETE +#undef IN +#undef THIS +#undef CONST +#undef NAN +#undef UNKNOWN +#undef NONE +#undef ANY +#undef IGNORE +#undef STRICT +#undef GetObject +#undef CreateSemaphore +#undef Yield +#undef RotateRight32 +#undef RotateLeft32 +#undef RotateRight64 +#undef RotateLeft64 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUCachingAllocator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUCachingAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..9fe6ecf7e59c18eaa8cd6afc37aa06a2045cf8aa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUCachingAllocator.h @@ -0,0 +1,121 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10::xpu::XPUCachingAllocator { + +class XPUAllocator : public DeviceAllocator { + public: + virtual void init(c10::DeviceIndex device_count) = 0; + virtual void* raw_alloc(size_t nbytes) = 0; + virtual void raw_delete(void* ptr) = 0; +}; + +C10_XPU_API extern std::atomic allocator; + +inline XPUAllocator* get() { + return allocator.load(); +} + +inline void init(c10::DeviceIndex device_count) { + get()->init(device_count); +} + +inline void emptyCache(MempoolId_t mempool_id = {0, 0}) { + get()->emptyCache(mempool_id); +} + +inline void resetPeakStats(DeviceIndex device) { + get()->resetPeakStats(device); +} + +inline void resetAccumulatedStats(DeviceIndex device) { + get()->resetAccumulatedStats(device); +} + +inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + DeviceIndex device) { + return get()->getDeviceStats(device); +} + +inline void* raw_alloc(size_t size) { + return get()->raw_alloc(size); +} + +inline void raw_delete(void* ptr) { + get()->raw_delete(ptr); +} + +inline void recordStream(const DataPtr& dataPtr, XPUStream stream) { + get()->recordStream(dataPtr, stream); +} + +C10_XPU_API void enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access); + +C10_XPU_API double getMemoryFraction(DeviceIndex device); + +C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device); + +C10_XPU_API void createOrIncrefPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr); + +C10_XPU_API void beginAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id, + std::function filter); + +C10_XPU_API void endAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API void releasePool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API int getPoolUseCount( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +} // namespace c10::xpu::XPUCachingAllocator + +namespace c10::xpu { + +using c10::CaptureId_t; +using c10::MempoolId_t; +struct C10_XPU_API MemPool { + MemPool( + XPUCachingAllocator::XPUAllocator* allocator = nullptr, + bool is_user_created = true, + bool use_on_oom = false); + MemPool(const MemPool&) = delete; + MemPool(MemPool&&) = default; + MemPool& operator=(const MemPool&) = delete; + MemPool& operator=(MemPool&&) = default; + ~MemPool(); + + MempoolId_t id(); + XPUCachingAllocator::XPUAllocator* allocator(); + int use_count(); + c10::DeviceIndex device(); + static MempoolId_t graph_pool_handle(bool is_user_created = true); + + private: + static std::atomic uid_; + static std::atomic uuid_; + XPUCachingAllocator::XPUAllocator* allocator_; + bool is_user_created_; + MempoolId_t id_; + c10::DeviceIndex device_; +}; +} // namespace c10::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUDeviceProp.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUDeviceProp.h new file mode 100644 index 0000000000000000000000000000000000000000..b85a34f0bc3d032fe403c5e758cfbba252b27871 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUDeviceProp.h @@ -0,0 +1,212 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::xpu { + +#define AT_FORALL_XPU_DEVICE_PROPERTIES(_) \ + /* the device name of this SYCL device. */ \ + _(name) \ + \ + /* the device type associated with the device. */ \ + _(device_type) \ + \ + /* the vendor of this SYCL device. */ \ + _(vendor) \ + \ + /* a backend-defined driver version as a std::string. */ \ + _(driver_version) \ + \ + /* the SYCL version as a std::string in the form . */ \ + _(version) \ + \ + /* true if the SYCL device is available. Otherwise, return false. */ \ + _(is_available) \ + \ + /* the maximum size in bytes of the arguments that can be passed to a \ + * kernel. */ \ + _(max_parameter_size) \ + \ + /* the number of parallel compute units available to the device. */ \ + _(max_compute_units) \ + \ + /* the maximum dimensions that specify the global and local work-item IDs \ + * used by the data parallel execution model. */ \ + _(max_work_item_dimensions) \ + \ + /* the maximum number of workitems that are permitted in a work-group \ + * executing a kernel on a single compute unit. */ \ + _(max_work_group_size) \ + \ + /* the maximum number of subgroups in a work-group for any kernel executed \ + * on the device. */ \ + _(max_num_sub_groups) \ + \ + /* a std::vector of size_t containing the set of sub-group sizes supported \ + * by the device. */ \ + _(sub_group_sizes) \ + \ + /* the maximum configured clock frequency of this SYCL device in MHz. */ \ + _(max_clock_frequency) \ + \ + /* the default compute device address space size specified as an unsigned \ + * integer value in bits. Must return either 32 or 64. */ \ + _(address_bits) \ + \ + /* the maximum size of memory object allocation in bytes. */ \ + _(max_mem_alloc_size) \ + \ + /* the minimum value in bits of the largest supported SYCL built-in data \ + * type if this SYCL device is not of device type \ + * sycl::info::device_type::custom. */ \ + _(mem_base_addr_align) \ + \ + /* a std::vector of info::fp_config describing the half/single/double \ + * precision floating-point capability of this SYCL device. */ \ + _(half_fp_config) \ + _(single_fp_config) \ + _(double_fp_config) \ + \ + /* the size of global device memory in bytes. */ \ + _(global_mem_size) \ + \ + /* the type of global memory cache supported. */ \ + _(global_mem_cache_type) \ + \ + /* the size of global memory cache in bytes. */ \ + _(global_mem_cache_size) \ + \ + /* the size of global memory cache line in bytes. */ \ + _(global_mem_cache_line_size) \ + \ + /* the type of local memory supported. */ \ + _(local_mem_type) \ + \ + /* the size of local memory arena in bytes. */ \ + _(local_mem_size) \ + \ + /* the maximum number of sub-devices that can be created when this device is \ + * partitioned. */ \ + _(partition_max_sub_devices) \ + \ + /* the resolution of device timer in nanoseconds. */ \ + _(profiling_timer_resolution) \ + \ + /* the preferred native vector width size for built-in scalar types that can \ + * be put into vectors. */ \ + _(preferred_vector_width_char) \ + _(preferred_vector_width_short) \ + _(preferred_vector_width_int) \ + _(preferred_vector_width_long) \ + _(preferred_vector_width_float) \ + _(preferred_vector_width_double) \ + _(preferred_vector_width_half) \ + \ + /* the native ISA vector width. The vector width is defined as the number of \ + * scalar elements that can be stored in the vector. */ \ + _(native_vector_width_char) \ + _(native_vector_width_short) \ + _(native_vector_width_int) \ + _(native_vector_width_long) \ + _(native_vector_width_float) \ + _(native_vector_width_double) \ + _(native_vector_width_half) + +#define AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(_) \ + /* the number of EUs associated with the Intel GPU. */ \ + _(gpu_eu_count, gpu_eu_count, 512) \ + \ + /* the number of EUs in a subslice. */ \ + _(gpu_eu_count_per_subslice, gpu_eu_count_per_subslice, 8) \ + \ + /* the simd width of EU of GPU. */ \ + _(gpu_eu_simd_width, gpu_eu_simd_width, 8) \ + \ + /* the number of hardware threads per EU of GPU. */ \ + _(gpu_hw_threads_per_eu, gpu_hw_threads_per_eu, 8) \ + \ + /* the device identifier of the Intel GPU, also known as the product ID. */ \ + _(device_id, device_id, 0) \ + \ + /* the device descriptor for device Universal Unique ID, 16 bytes*/ \ + _(uuid, device_info_uuid, (std::array{})) + +#define AT_FORALL_XPU_DEVICE_ASPECT(_) \ + /* sycl::half is supported on device. */ \ + _(fp16) \ + \ + /* double is supported on device. */ \ + _(fp64) \ + \ + /* 64-bit atomic operation is supported on device. */ \ + _(atomic64) + +#define AT_FORALL_XPU_EXP_CL_ASPECT(_) \ + /* conversion between single-precision 32-bit floating-point values and \ + * 16-bit bfloat16 values is supported on device. */ \ + _(bfloat16_conversions) \ + \ + /* specialized hardware to compute MMA is supported on device. */ \ + _(subgroup_matrix_multiply_accumulate) \ + \ + /* specialized hardware to compute MMA for 32-bit floating-point is \ + * supported on device. */ \ + _(subgroup_matrix_multiply_accumulate_tensor_float32) \ + \ + /* block read operations for efficient matrix multiplication is supported on \ + * device. */ \ + _(subgroup_2d_block_io) + +#define AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(_) \ + /* the device architecture of this SYCL device. */ \ + _(architecture) + +#define _DEFINE_SYCL_PROP(ns, property, member) \ + ns::property::return_type member; + +#define DEFINE_DEVICE_PROP(property) \ + _DEFINE_SYCL_PROP(sycl::info::device, property, property) + +#define DEFINE_PLATFORM_PROP(property, member) \ + _DEFINE_SYCL_PROP(sycl::info::platform, property, member) + +#define DEFINE_EXT_DEVICE_PROP(property, ...) \ + _DEFINE_SYCL_PROP(sycl::ext::intel::info::device, property, property) + +#define DEFINE_DEVICE_ASPECT(member) bool has_##member; + +#define DEFINE_EXP_DEVICE_PROP(property) \ + _DEFINE_SYCL_PROP( \ + sycl::ext::oneapi::experimental::info::device, property, property) + +struct C10_XPU_API DeviceProp { + AT_FORALL_XPU_DEVICE_PROPERTIES(DEFINE_DEVICE_PROP); + + // the platform name. + DEFINE_PLATFORM_PROP(name, platform_name); + + AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(DEFINE_EXT_DEVICE_PROP); + + AT_FORALL_XPU_DEVICE_ASPECT(DEFINE_DEVICE_ASPECT); + + AT_FORALL_XPU_EXP_CL_ASPECT(DEFINE_DEVICE_ASPECT); + +#if SYCL_COMPILER_VERSION >= 20250000 + AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(DEFINE_EXP_DEVICE_PROP); +#endif +}; + +#undef _DEFINE_SYCL_PROP +#undef DEFINE_DEVICE_PROP +#undef DEFINE_PLATFORM_PROP +#undef DEFINE_EXT_DEVICE_PROP +#undef DEFINE_DEVICE_ASPECT +#undef DEFINE_EXP_DEVICE_PROP + +} // namespace c10::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUEvent.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..596fdfcc0ff06ccdb4395c3989e987836804eddc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUEvent.h @@ -0,0 +1,183 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace c10::xpu { + +/* + * XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are + * constructed lazily when first recorded. It has a device, and this device is + * acquired from the first recording stream. Later streams that record the event + * must match the same device. + * + * Currently, XPUEvent does NOT support to export an inter-process event from + * another process via inter-process communication(IPC). So it means that + * inter-process communication for event handles between different processes is + * not available. This could impact some applications that rely on cross-process + * synchronization and communication. + */ +struct XPUEvent { + // Constructors + XPUEvent(bool enable_timing = false) noexcept + : enable_timing_{enable_timing} {} + + ~XPUEvent() { + if (isCreated()) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + c10::kXPU, reinterpret_cast(event_.get())); + } + } + } + + C10_DISABLE_COPY_AND_ASSIGN(XPUEvent); + + XPUEvent(XPUEvent&& other) = default; + XPUEvent& operator=(XPUEvent&& other) = default; + + operator sycl::event&() const { + return event(); + } + + std::optional device() const { + if (isCreated()) { + return c10::Device(c10::kXPU, device_index_); + } else { + return std::nullopt; + } + } + + inline bool isCreated() const { + return (event_.get() != nullptr); + } + + DeviceIndex device_index() const { + return device_index_; + } + + sycl::event& event() const { + return *event_; + } + + bool query() const { + using namespace sycl::info; + if (!isCreated()) { + return true; + } + + return event().get_info() == + event_command_status::complete; + } + + void record() { + record(getCurrentXPUStream()); + } + + void recordOnce(const XPUStream& stream) { + if (!isCreated()) { + record(stream); + } + } + + void record(const XPUStream& stream) { + if (!isCreated()) { + device_index_ = stream.device_index(); + assignEvent(stream.queue()); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation( + c10::kXPU, reinterpret_cast(event_.get())); + } + } else { + TORCH_CHECK( + device_index_ == stream.device_index(), + "Event device ", + device_index_, + " does not match recording stream's device ", + stream.device_index(), + "."); + reassignEvent(stream.queue()); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + c10::kXPU, + reinterpret_cast(event_.get()), + reinterpret_cast(&stream.queue())); + } + } + + void block(const XPUStream& stream) { + if (isCreated()) { + std::vector event_list{event()}; + // Make this stream wait until event_ is completed. + stream.queue().ext_oneapi_submit_barrier(event_list); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + c10::kXPU, + reinterpret_cast(event_.get()), + reinterpret_cast(&stream.queue())); + } + } + } + + double elapsed_time(const XPUEvent& other) const { + TORCH_CHECK( + isCreated() && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + TORCH_CHECK( + enable_timing_ && other.enable_timing_, + "Both events must be created with argument 'enable_timing=True'."); + + using namespace sycl::info::event_profiling; + // Block until both of the recorded events are completed. + uint64_t end_time_ns = other.event().get_profiling_info(); + uint64_t start_time_ns = event().get_profiling_info(); + // Return the eplased time in milliseconds. + return 1e-6 * + (static_cast(end_time_ns) - static_cast(start_time_ns)); + } + + void synchronize() const { + if (isCreated()) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization( + c10::kXPU, reinterpret_cast(event_.get())); + } + event().wait_and_throw(); + } + } + + private: + void assignEvent(sycl::queue& queue) { + if (enable_timing_) { + event_ = std::make_unique( + sycl::ext::oneapi::experimental::submit_profiling_tag(queue)); + } else { + event_ = std::make_unique(queue.ext_oneapi_submit_barrier()); + } + } + + void reassignEvent(sycl::queue& queue) { + event_.reset(); + assignEvent(queue); + } + + bool enable_timing_ = false; + c10::DeviceIndex device_index_ = -1; + // Only need to track the last event, as events in an in-order queue are + // executed sequentially. + std::unique_ptr event_; +}; + +} // namespace c10::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUException.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUException.h new file mode 100644 index 0000000000000000000000000000000000000000..d5d6d56a1560728c6604d04aaa2aa75c4c615aae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUException.h @@ -0,0 +1,27 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::xpu { + +static inline sycl::async_handler asyncHandler = [](sycl::exception_list el) { + if (el.size() == 0) { + return; + } + for (const auto& e : el) { + try { + std::rethrow_exception(e); + } catch (sycl::exception& e) { + TORCH_WARN("SYCL Exception: ", e.what()); + } + } + throw; +}; + +} // namespace c10::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUFunctions.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..e5017a054d32448a372290fcab2adfdea3e7fb36 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUFunctions.h @@ -0,0 +1,50 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +// The naming convention used here matches the naming convention of torch.xpu + +namespace c10::xpu { + +// Log a warning only once if no devices are detected. +C10_XPU_API DeviceIndex device_count(); + +// Throws an error if no devices are detected. +C10_XPU_API DeviceIndex device_count_ensure_non_zero(); + +C10_XPU_API DeviceIndex current_device(); + +C10_XPU_API void set_device(DeviceIndex device); + +C10_XPU_API DeviceIndex exchange_device(DeviceIndex device); + +C10_XPU_API DeviceIndex maybe_exchange_device(DeviceIndex to_device); + +C10_XPU_API sycl::device& get_raw_device(DeviceIndex device); + +C10_XPU_API sycl::context& get_device_context(); + +C10_XPU_API void get_device_properties( + DeviceProp* device_prop, + DeviceIndex device); + +C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr); + +static inline void check_device_index(DeviceIndex device_index) { + TORCH_CHECK( + device_index >= 0 && device_index < c10::xpu::device_count(), + "The device index is out of range. It must be in [0, ", + static_cast(c10::xpu::device_count()), + "), but got ", + static_cast(device_index), + "."); +} + +} // namespace c10::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUGraphsC10Utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUGraphsC10Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..437dda44bfc4826d05389b530a7cd54083ec8f08 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUGraphsC10Utils.h @@ -0,0 +1,47 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +// XPU Graphs utils used by c10 and aten. +using namespace sycl::ext::oneapi::experimental; +namespace c10::xpu { + +static_assert( + int8_t(queue_state::executing) == 0, + "unexpected int(queue_state::executing) value"); +static_assert( + int8_t(queue_state::recording) == 1, + "unexpected int(queue_state::recording) value"); + +enum class CaptureStatus : int8_t { + Executing = int8_t(queue_state::executing), + Recording = int8_t(queue_state::recording) +}; + +inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { + switch (status) { + case CaptureStatus::Executing: + os << "Executing"; + break; + case CaptureStatus::Recording: + os << "Recording"; + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Unknown XPU graph CaptureStatus", int(status)); + } + return os; +} + +inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { + auto state = c10::xpu::getCurrentXPUStream().queue().ext_oneapi_get_state(); + return CaptureStatus(state); +} + +} // namespace c10::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUMacros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUMacros.h new file mode 100644 index 0000000000000000000000000000000000000000..43a42c2a6f8a47a2276268e58edc176ec5f6a781 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUMacros.h @@ -0,0 +1,38 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif + +// See c10/macros/Export.h for a detailed explanation of what the function +// of these macros are. We need one set of macros for every separate library +// we build. + +#ifdef _WIN32 +#if defined(C10_XPU_BUILD_SHARED_LIBS) +#define C10_XPU_EXPORT __declspec(dllexport) +#define C10_XPU_IMPORT __declspec(dllimport) +#else +#define C10_XPU_EXPORT +#define C10_XPU_IMPORT +#endif +#else // _WIN32 +#if defined(__GNUC__) +#define C10_XPU_EXPORT __attribute__((__visibility__("default"))) +#else // defined(__GNUC__) +#define C10_XPU_EXPORT +#endif // defined(__GNUC__) +#define C10_XPU_IMPORT C10_XPU_EXPORT +#endif // _WIN32 + +// This one is being used by libc10_xpu.so +#ifdef C10_XPU_BUILD_MAIN_LIB +#define C10_XPU_API C10_XPU_EXPORT +#else +#define C10_XPU_API C10_XPU_IMPORT +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUStream.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUStream.h new file mode 100644 index 0000000000000000000000000000000000000000..df79df4945aa93da62b5faf0bf931a44cd09bf2d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/XPUStream.h @@ -0,0 +1,217 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10::xpu { + +/* + * Note [Stream Management] + * + * An XPUStream is an abstraction of an actual SYCL queue in which SYCL kernel + * can execute. Currently, there are several pools per device to manage SYCL + * queue, and a device's pool is lazily created. + * + * There are two pools per device. The first pool contains "normal priority" + * queues. The second pool is the "high priority" queues. There are 32 queues in + * per pool per device, and when a queue is requested one of these queues is + * returned round-robin. That is, the first queue requested is at index 0, the + * second at index 1... to index 31, then index 0 again. + * + * This means that if 33 queues are requested, the first and last queues + * requested are actually the same queue (under the covers) and kernels enqueued + * on them cannot run concurrently. + * + * It is safe to enqueue a kernel on the same queue from two different + * threads as the SYCL specification described. + */ + +static constexpr int max_compile_time_stream_priorities = 3; + +/* + * This serves as a wrapper around c10::Stream and acts as a representation for + * a SYCL queue, which allows asynchronous execution of XPU tasks. + */ +class C10_XPU_API XPUStream { + public: + enum Unchecked { UNCHECKED }; + + /// Construct a XPUStream from a Stream. This construction is checked, and + /// will raise an error if the Stream is not, in fact, a XPU stream. + explicit XPUStream(Stream stream) : stream_(stream) { + TORCH_CHECK(stream_.device_type() == DeviceType::XPU); + } + + /// Construct a XPUStream from a Stream with no error checking. + explicit XPUStream(Unchecked, Stream stream) : stream_(stream) {} + + bool operator==(const XPUStream& other) const noexcept { + return unwrap() == other.unwrap(); + } + + bool operator!=(const XPUStream& other) const noexcept { + return unwrap() != other.unwrap(); + } + + /// Implicit conversion to sycl::queue&. + operator sycl::queue&() const { + return queue(); + } + + /// Implicit conversion to sycl::queue*. + operator sycl::queue*() const { + return &queue(); + } + + /// Implicit conversion to Stream (a.k.a., forget that the stream is a + /// XPU stream). + operator Stream() const { + return unwrap(); + } + + /// Get the XPU device type that this stream is associated with. + DeviceType device_type() const { + return DeviceType::XPU; + } + + /// Get the XPU device index that this stream is associated with. + DeviceIndex device_index() const { + return stream_.device_index(); + } + + /// Get the full Device that this stream is associated with. The Device is + /// guaranteed to be a XPU device. + Device device() const { + return Device(DeviceType::XPU, device_index()); + } + + /// Return the stream ID corresponding to this particular stream. StreamId is + /// a int64_t representation generated by its type and index. + StreamId id() const { + return stream_.id(); + } + + /// Return true if all enqueued tasks in this stream have been completed, + /// otherwise return false. + bool query() const { + return queue().ext_oneapi_empty(); + } + + /// Performs a blocking wait for the completion of all enqueued tasks in this + /// stream. + void synchronize() const { + queue().wait_and_throw(); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_stream_synchronization( + c10::kXPU, reinterpret_cast(&queue())); + } + } + + /// Return the priority that this stream is associated with. Lower numbers + /// represent higher priority. + int priority() const; + + /// Explicit conversion to sycl::queue&. + sycl::queue& queue() const; + + /// Explicit conversion to Stream. + Stream unwrap() const { + return stream_; + } + + /// Reversibly pack a XPUStream into a struct representation. The XPUStream + /// can be unpacked using unpack3(). + struct c10::StreamData3 pack3() const { + return stream_.pack3(); + } + + /// Unpack a XPUStream from the 3 fields generated by pack3(). + static XPUStream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + return XPUStream(Stream::unpack3(stream_id, device_index, device_type)); + } + + /// Return the range of priority **supported by PyTorch**. + static std::tuple priority_range() { + // See Note [XPU Stream priorities] + return std::make_tuple(1, -max_compile_time_stream_priorities + 2); + } + + private: + Stream stream_; +}; + +/** + * Get a stream from the pool in a round-robin fashion. + * + * You can request a stream from the highest priority pool by setting + * isHighPriority to true for a specific device. + */ +C10_XPU_API XPUStream +getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); + +/** + * Get a stream from the pool in a round-robin fashion. + * + * You can request a stream by setting a priority value for a specific device. + * The priority number lower, the priority higher. + */ +C10_XPU_API XPUStream +getStreamFromPool(const int priority, DeviceIndex device = -1); + +/** + * Get an XPUStream from an external SYCL queue. + * + * This function allows interoperability with other libraries by enabling + * the use of an external SYCL queue that was not created by PyTorch. This + * can be useful for data exchange or other operations where integration + * with non-PyTorch queues is required. + * + * NOTE: It is the user's responsibility to ensure that the referenced SYCL + * queue remains alive while the corresponding XPUStream, or any c10::Stream + * derived from it, is in use. The different SYCL queue pointers will result in + * distinct XPUStream instances, even if the SYCL queues they dereference are + * equivalent. + */ +C10_XPU_API XPUStream +getStreamFromExternal(sycl::queue* ext_queue, DeviceIndex device_index); + +/** + * Get the current XPU stream, for the passed XPU device, or for the current + * device if no device index is passed. + */ +C10_XPU_API XPUStream getCurrentXPUStream(DeviceIndex device = -1); + +/** + * Set the current stream on the device of the passed in stream to be the passed + * in stream. + */ +C10_XPU_API void setCurrentXPUStream(XPUStream stream); + +C10_XPU_API std::ostream& operator<<(std::ostream& stream, const XPUStream& s); + +/** + * Block all reserved SYCL queues in the stream pools on the device, and wait + * for their synchronizations. + */ +C10_XPU_API void syncStreamsOnDevice(DeviceIndex device = -1); + +} // namespace c10::xpu + +namespace std { +template <> +struct hash { + size_t operator()(c10::xpu::XPUStream s) const noexcept { + return std::hash{}(s.unwrap()); + } +}; +} // namespace std + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/impl/XPUGuardImpl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/impl/XPUGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..0d700f946ebe76abf99c2641448f5b4e2c3241eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/impl/XPUGuardImpl.h @@ -0,0 +1,223 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace c10::xpu::impl { + +struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = kXPU; + + XPUGuardImpl() = default; + + explicit XPUGuardImpl(DeviceType t) { + TORCH_CHECK( + t == kXPU, "XPUGuardImpl initialized with non-XPU DeviceType: ", t); + } + + DeviceType type() const override { + return kXPU; + } + + Device exchangeDevice(Device d) const override { + TORCH_CHECK(d.is_xpu(), "Expected a XPU device, but got ", d); + const auto old_device_index = c10::xpu::exchange_device(d.index()); + return Device(kXPU, old_device_index); + } + + Device getDevice() const override { + const auto device = c10::xpu::current_device(); + return Device(kXPU, device); + } + + void setDevice(Device d) const override { + TORCH_CHECK(d.is_xpu(), "Expected a XPU device, but got ", d); + c10::xpu::set_device(d.index()); + } + + void uncheckedSetDevice(Device d) const noexcept override { + c10::xpu::set_device(d.index()); + } + + Stream getStream(Device d) const override { + return getCurrentXPUStream(d.index()).unwrap(); + } + + Stream getNewStream(Device d, int priority = 0) const override { + return getStreamFromPool(priority, d.index()); + } + + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) + const override { + return getStreamFromPool(isHighPriority, d.index()); + } + + // NB: These do NOT set the current device + Stream exchangeStream(Stream s) const override { + const XPUStream stream(s); + const auto old_stream = getCurrentXPUStream(s.device().index()); + setCurrentXPUStream(stream); + return old_stream.unwrap(); + } + + DeviceIndex deviceCount() const noexcept override { + return c10::xpu::device_count(); + } + + // Event-related functions + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { + if (!event) + return; + + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + c10::kXPU, reinterpret_cast(event)); + } + + delete reinterpret_cast(event); + } + + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK( + device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); + + auto* xpu_event = reinterpret_cast(*event); + const XPUStream xpu_stream{stream}; + + // Delete the event previously recorded. + if (xpu_event) + delete xpu_event; +#if SYCL_COMPILER_VERSION >= 20250000 + if (flag == EventFlag::BACKEND_DEFAULT) { + // Use the profiling tag to record the event to enable timing feature. + xpu_event = + new sycl::event(sycl::ext::oneapi::experimental::submit_profiling_tag( + xpu_stream.queue())); + } else { + xpu_event = + new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier()); + } +#else + xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier()); +#endif + *event = reinterpret_cast(xpu_event); + + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + c10::kXPU, + reinterpret_cast(xpu_event), + reinterpret_cast(&xpu_stream.queue())); + } + } + + void block(void* event, const Stream& stream) const override { + if (!event) + return; + auto* xpu_event = reinterpret_cast(event); + std::vector event_list{*xpu_event}; + const XPUStream xpu_stream(stream); + xpu_stream.queue().ext_oneapi_submit_barrier(event_list); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + c10::kXPU, + reinterpret_cast(xpu_event), + reinterpret_cast(&xpu_stream.queue())); + } + } + + bool queryEvent(void* event) const override { + using namespace sycl::info; + if (!event) + return true; + auto* xpu_event = reinterpret_cast(event); + return xpu_event->get_info() == + event_command_status::complete; + } + + double elapsedTime( + void* start_event, + void* end_event, + const DeviceIndex device_index) const override { +#if SYCL_COMPILER_VERSION < 20250000 + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "elapsedTime requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer."); +#endif + TORCH_CHECK( + start_event && end_event, + "Both events must be recorded before calculating elapsed time."); + auto* xpu_start_event = reinterpret_cast(start_event); + auto* xpu_end_event = reinterpret_cast(end_event); + + using namespace sycl::info::event_profiling; + // Block until both of the recorded events are completed. + uint64_t end_time_ns = xpu_end_event->get_profiling_info(); + uint64_t start_time_ns = xpu_start_event->get_profiling_info(); + // Return the eplased time in milliseconds. + return 1e-6 * + (static_cast(end_time_ns) - static_cast(start_time_ns)); + } + + // Stream-related functions + bool queryStream(const Stream& stream) const override { + const XPUStream xpu_stream{stream}; + return xpu_stream.query(); + } + + void synchronizeStream(const Stream& stream) const override { + const XPUStream xpu_stream{stream}; + xpu_stream.synchronize(); + } + + void synchronizeEvent(void* event) const override { + if (!event) + return; + auto* xpu_event = reinterpret_cast(event); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization( + c10::kXPU, reinterpret_cast(xpu_event)); + } + xpu_event->wait_and_throw(); + } + + void synchronizeDevice(const c10::DeviceIndex device_index) const override { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_device_synchronization(c10::kXPU); + } + c10::xpu::syncStreamsOnDevice(device_index); + } + + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { + const XPUStream xpu_stream{stream}; + XPUCachingAllocator::recordStream(data_ptr, xpu_stream); + } +}; + +} // namespace c10::xpu::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/test/impl/XPUTest.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/test/impl/XPUTest.h new file mode 100644 index 0000000000000000000000000000000000000000..336c8349121389fd6dc64732ef50977e1cb2e0d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/c10/xpu/test/impl/XPUTest.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#include + +static inline void initHostData(int* hostData, int numel) { + for (const auto i : c10::irange(numel)) { + hostData[i] = i; + } +} + +static inline void clearHostData(int* hostData, int numel) { + for (const auto i : c10::irange(numel)) { + hostData[i] = 0; + } +} + +static inline void validateHostData(int* hostData, int numel) { + for (const auto i : c10::irange(numel)) { + EXPECT_EQ(hostData[i], i); + } +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/common.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/common.h new file mode 100644 index 0000000000000000000000000000000000000000..f8de86b9ed8e3fc25a3e6efe20bc36f5b29336c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/common.h @@ -0,0 +1,66 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef CAFFE2_CORE_COMMON_H_ +#define CAFFE2_CORE_COMMON_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __APPLE__ +#include +#endif + +#if defined(_MSC_VER) +#include +#else +#include +#endif + +// Macros used during the build of this caffe2 instance. This header file +// is automatically generated by the cmake script during build. +#include "caffe2/core/macros.h" + +#include + +namespace caffe2 { + +// Using statements for common classes that we refer to in caffe2 very often. +// Note that we only place it inside caffe2 so the global namespace is not +// polluted. +/* using override */ +using std::set; +using std::string; +using std::unique_ptr; +using std::vector; + +// Define alignment macro that is cross platform +#if (defined _MSC_VER && !defined NOMINMAX) +#define NOMINMAX +#endif + +using std::make_unique; + +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +using ::round; +#else +using std::round; +#endif // defined(__ANDROID__) && !defined(__NDK_MAJOR__) + +// Returns which setting Caffe2 was configured and built with (exported from +// CMake) +TORCH_API const std::map& GetBuildOptions(); + +} // namespace caffe2 + +#endif // CAFFE2_CORE_COMMON_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/macros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/macros.h new file mode 100644 index 0000000000000000000000000000000000000000..ae86a3366590c8538b92cc7e92191365ef3545c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/macros.h @@ -0,0 +1,75 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Automatically generated header file for caffe2 macros. These +// macros are used to build the Caffe2 binary, and if you are +// building a dependent library, they will need to be set as well +// for your program to link correctly. + +#pragma once + +#define CAFFE2_BUILD_SHARED_LIBS +/* #undef CAFFE2_FORCE_FALLBACK_CUDA_MPI */ +/* #undef CAFFE2_HAS_MKL_DNN */ +/* #undef CAFFE2_HAS_MKL_SGEMM_PACK */ +#define CAFFE2_PERF_WITH_AVX +#define CAFFE2_PERF_WITH_AVX2 +/* #undef CAFFE2_THREADPOOL_MAIN_IMBALANCE */ +/* #undef CAFFE2_THREADPOOL_STATS */ +/* #undef CAFFE2_USE_ACCELERATE */ +#define CAFFE2_USE_CUDNN +/* #undef CAFFE2_USE_EIGEN_FOR_BLAS */ +/* #undef CAFFE2_USE_FBCODE */ +/* #undef CAFFE2_USE_GOOGLE_GLOG */ +/* #undef CAFFE2_USE_LITE_PROTO */ +#define CAFFE2_USE_MKL +#define USE_MKLDNN +/* #undef CAFFE2_USE_NVTX */ +/* #undef CAFFE2_USE_ITT */ + +#ifndef EIGEN_MPL2_ONLY +#define EIGEN_MPL2_ONLY +#endif + +// Useful build settings that are recorded in the compiled binary +// torch.__config__.show() +#define CAFFE2_BUILD_STRINGS { \ + {"TORCH_VERSION", "2.10.0"}, \ + {"CXX_COMPILER", "/opt/rh/gcc-toolset-13/root/usr/bin/c++"}, \ + {"CXX_FLAGS", " -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_FBGEMM_GENAI -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow"}, \ + {"BUILD_TYPE", "Release"}, \ + {"BLAS_INFO", "mkl"}, \ + {"LAPACK_INFO", "mkl"}, \ + {"USE_CUDA", "ON"}, \ + {"USE_ROCM", "OFF"}, \ + {"CUDA_VERSION", "12.8"}, \ + {"ROCM_VERSION", ""}, \ + {"USE_CUDNN", "ON"}, \ + {"COMMIT_SHA", "449b1768410104d3ed79d3bcfe4ba1d65c7f22c0"}, \ + {"CUDNN_VERSION", "9.10.2"}, \ + {"USE_NCCL", "1"}, \ + {"USE_MPI", "OFF"}, \ + {"USE_GFLAGS", "OFF"}, \ + {"USE_GLOG", "OFF"}, \ + {"USE_GLOO", "ON"}, \ + {"USE_NNPACK", "ON"}, \ + {"USE_OPENMP", "ON"}, \ + {"FORCE_FALLBACK_CUDA_MPI", ""}, \ + {"HAS_MKL_DNN", ""}, \ + {"HAS_MKL_SGEMM_PACK", ""}, \ + {"PERF_WITH_AVX", "1"}, \ + {"PERF_WITH_AVX2", "1"}, \ + {"USE_ACCELERATE", ""}, \ + {"USE_EIGEN_FOR_BLAS", ""}, \ + {"USE_LITE_PROTO", ""}, \ + {"USE_MKL", "ON"}, \ + {"USE_MKLDNN", "ON"}, \ + {"USE_NVTX", ""}, \ + {"USE_ITT", ""}, \ + {"USE_ROCM_KERNEL_ASSERT", "OFF"}, \ + {"USE_CUSPARSELT", "1"}, \ + {"USE_XPU", "OFF"}, \ + {"USE_XCCL", "OFF"}, \ +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/timer.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/timer.h new file mode 100644 index 0000000000000000000000000000000000000000..54ff81fc25e27eb38cc23e497b692f321b71c6b4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/core/timer.h @@ -0,0 +1,53 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef CAFFE2_CORE_TIMER_H_ +#define CAFFE2_CORE_TIMER_H_ + +#include + +#include "caffe2/core/common.h" + +namespace caffe2 { + +/** + * @brief A simple timer object for measuring time. + * + * This is a minimal class around a std::chrono::high_resolution_clock that + * serves as a utility class for testing code. + */ +class Timer { + public: + typedef std::chrono::high_resolution_clock clock; + typedef std::chrono::nanoseconds ns; + Timer() { Start(); } + /** + * @brief Starts a timer. + */ + inline void Start() { start_time_ = clock::now(); } + inline float NanoSeconds() { + return static_cast( + std::chrono::duration_cast(clock::now() - start_time_).count()); + } + /** + * @brief Returns the elapsed time in milliseconds. + */ + inline float MilliSeconds() { return NanoSeconds() / 1000000.f; } + /** + * @brief Returns the elapsed time in microseconds. + */ + inline float MicroSeconds() { return NanoSeconds() / 1000.f; } + /** + * @brief Returns the elapsed time in seconds. + */ + inline float Seconds() { return NanoSeconds() / 1000000000.f; } + + protected: + std::chrono::time_point start_time_; + C10_DISABLE_COPY_AND_ASSIGN(Timer); +}; +} + +#endif // CAFFE2_CORE_TIMER_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/batch_box_cox_vec.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/batch_box_cox_vec.h new file mode 100644 index 0000000000000000000000000000000000000000..7c7c0b7ec332ff3e66c897806c6e26dbbb7dee9d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/batch_box_cox_vec.h @@ -0,0 +1,326 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include "vectorizer.h" +#include + +namespace caffe2::details { + +namespace { +void TileIndicesInPlace(std::vector& v, const std::size_t D, const std::size_t K) { + auto n = v.size(); + v.resize(K * n); + for (const auto k : c10::irange(1, K)) { + for (const auto j : c10::irange(n)) { + v[k * n + j] = v[j] + k * D; + } + } +} + +// MKL VML function templates. +template +void PackV(const int N, const T* a, const int* ia, T* y); +template +void UnpackV(const int N, const T* a, T* y, const int* iy); + +#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ + template <> \ + void PackV(const int N, const T* a, const int* ia, T* y) { \ + OriginalFunc(N, a, ia, y); \ + } +DELEGATE_PACKV_FUNCTION(float, vsPackV) +DELEGATE_PACKV_FUNCTION(double, vdPackV) +#undef DELEGATE_PACKV_FUNCTION + +#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \ + template <> \ + void UnpackV(const int N, const T* a, T* y, const int* iy) { \ + OriginalFunc(N, a, y, iy); \ + } +DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) +DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) +#undef DELEGATE_UNPACKV_FUNCTION + +#ifndef FAST_VECTORIZED_KERNEL +template +void box_cox_zero_lambda( + size_t D, + const T* const self_data, + const T* const lambda2_data, + T k_eps, + T* const output_data) { + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(self_data + j); + auto lambda2 = Vec::loadu(lambda2_data + j); + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps_vec); + auto res = max.log(); + res.store(output_data + j); + } + for ( ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); + } +} + +template +at::vec::Vectorized box_cox_nonzero_lambda_impl( + at::vec::Vectorized data, + at::vec::Vectorized lambda1, + at::vec::Vectorized lambda2, + at::vec::Vectorized k_eps) { + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); + auto pow = max.pow(lambda1); + return at::vec::fmsub(pow, lambda_over_1, lambda_over_1); +} + +template +void box_cox_nonzero_lambda( + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(data_ptr + j); + auto lambda1 = Vec::loadu(lambda1_ptr + j); + auto lambda2 = Vec::loadu(lambda2_ptr + j); + auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec); + res.store(out + j); + } + if (j < D) { + auto remaining = D - j; + auto data = Vec::loadu(data_ptr + j, remaining); + auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining); + auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining); + auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec); + res.store(out + j, remaining); + } +} +#else +template +void box_cox_zero_lambda( + size_t D, + const T* const self_data, + const T* const lambda2_data, + T k_eps, + T* const output_data) { + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); + } +} + +template +void box_cox_nonzero_lambda( + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + FAST_MATH + auto sum = data_ptr[j] + lambda2_ptr[j]; + auto max = std::max(sum, k_eps); + auto lamda1 = lambda1_ptr[j]; + auto lambda_over_1 = 1 / lamda1; + if constexpr (std::is_same::value) { + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + } + auto pow = std::pow(max, lamda1); + out[j] = pow * lambda_over_1 - lambda_over_1; + } +} +#endif // FAST_VECTORIZED_KERNEL + +template +void box_cox_mixed_lambda( + const T* const self_data, + const std::vector& nonzeros, + const std::vector& zeros, + const T* const lambda1, + const T* const lambda2, + const T* const lambda2_z_, + T k_eps, + T* const buffer, + T* const output_data) { + PackV(nonzeros.size(), self_data, nonzeros.data(), buffer); + box_cox_nonzero_lambda( + nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer); + UnpackV(nonzeros.size(), buffer, output_data, nonzeros.data()); + + PackV(zeros.size(), self_data, zeros.data(), buffer); + box_cox_zero_lambda( + zeros.size(), buffer, lambda2_z_, k_eps, buffer); + UnpackV(zeros.size(), buffer, output_data, zeros.data()); +} + +template +void TileArrayIntoVector( + const T* const a, + const size_t D, + const int K, + std::vector& b) { + b.resize(K * D); + for (const auto k : c10::irange(K)) { + std::copy(a, a + D, b.begin() + k * D); + } +} + +template +void compute_batch_box_cox_vec_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const T* self_data, + const T* __restrict lambda1_data, + const T* __restrict lambda2_data, + T* output_data) { + constexpr T k_eps = static_cast(1e-6); + + FOLLY_DECLARE_REUSED(zeros, std::vector); + FOLLY_DECLARE_REUSED(nonzeros, std::vector); + // Don't bother calling reserve; calls after the first will get a + // correctly-sized allocation anyway. + for (const auto j : c10::irange(D)) { + if (lambda1_data[j] == 0) { + zeros.push_back(j); + } else { + nonzeros.push_back(j); + } + } + + // Process K rows at a time for effective vectorization with small rows. + const auto K = std::min(N, (block_size + D - 1) / D); + + FOLLY_DECLARE_REUSED(lambda1_, std::vector); + FOLLY_DECLARE_REUSED(lambda2_, std::vector); + FOLLY_DECLARE_REUSED(lambda2_z_, std::vector); + + if (nonzeros.size() == D) { + // ((x + lambda2)^lambda1 - 1)/lambda1, if lambda1 != 0 + size_t i = 0; + if (K > 1) { + TileArrayIntoVector(lambda1_data, D, K, lambda1_); + TileArrayIntoVector(lambda2_data, D, K, lambda2_); + DCHECK_EQ(K * D, lambda1_.size()); + DCHECK_EQ(K * D, lambda2_.size()); + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_nonzero_lambda( + K * D, + self_data, + lambda1_.data(), + lambda2_.data(), + k_eps, + output_data); + } + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_nonzero_lambda( + D, self_data, lambda1_data, lambda2_data, k_eps, output_data); + } + } else if (zeros.size() == D) { + // ln(x + lambda2), if lambda1 == 0 + size_t i = 0; + if (K > 1) { + TileArrayIntoVector(lambda2_data, D, K, lambda2_z_); + DCHECK_EQ(K * D, lambda2_z_.size()); + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_zero_lambda( + K * D, self_data, lambda2_z_.data(), k_eps, output_data); + } + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_zero_lambda( + D, self_data, lambda2_data, k_eps, output_data); + } + } else { + // mix zeros and nonzeros + const size_t n = nonzeros.size(); + if (K > 1) { + TileIndicesInPlace(nonzeros, 0, K); + TileIndicesInPlace(zeros, 0, K); + } + + FOLLY_DECLARE_REUSED(buffer, std::vector); + + buffer.resize(std::max(nonzeros.size(), zeros.size())); + lambda1_.resize(nonzeros.size()); + lambda2_.resize(nonzeros.size()); + lambda2_z_.resize(zeros.size()); + PackV(nonzeros.size(), lambda1_data, nonzeros.data(), lambda1_.data()); + PackV(nonzeros.size(), lambda2_data, nonzeros.data(), lambda2_.data()); + PackV(zeros.size(), lambda2_data, zeros.data(), lambda2_z_.data()); + + size_t i = 0; + if (K > 1) { + // Truncate to original size, and re-tile with offsets this time. + nonzeros.resize(n); + DCHECK_GT(D, n); + zeros.resize(D - n); + TileIndicesInPlace(nonzeros, D, K); + TileIndicesInPlace(zeros, D, K); + DCHECK_EQ(nonzeros.size(), lambda1_.size()); + DCHECK_EQ(nonzeros.size(), lambda2_.size()); + DCHECK_EQ(zeros.size(), lambda2_z_.size()); + + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_mixed_lambda( + self_data, + nonzeros, + zeros, + lambda1_.data(), + lambda2_.data(), + lambda2_z_.data(), + k_eps, + buffer.data(), + output_data); + } + // Truncate to original size. + nonzeros.resize(n); + zeros.resize(D - n); + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_mixed_lambda( + self_data, + nonzeros, + zeros, + lambda1_.data(), + lambda2_.data(), + lambda2_z_.data(), + k_eps, + buffer.data(), + output_data); + } + } +} +} // namespace + +} // namespace caffe2::details + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/common.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/common.h new file mode 100644 index 0000000000000000000000000000000000000000..f927b1ac74631203bfb9ac4bf869d0e2fa7b0a7c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/common.h @@ -0,0 +1,145 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// !!!! PLEASE READ !!!! +// Minimize (transitively) included headers from _avx*.cc because some of the +// functions defined in the headers compiled with platform dependent compiler +// options can be reused by other translation units generating illegal +// instruction run-time error. + +// Common utilities for writing performance kernels and easy dispatching of +// different backends. +/* +The general workflow shall be as follows, say we want to +implement a functionality called void foo(int a, float b). + +In foo.h, do: + void foo(int a, float b); + +In foo_avx512.cc, do: + void foo__avx512(int a, float b) { + [actual avx512 implementation] + } + +In foo_avx2.cc, do: + void foo__avx2(int a, float b) { + [actual avx2 implementation] + } + +In foo_avx.cc, do: + void foo__avx(int a, float b) { + [actual avx implementation] + } + +In foo.cc, do: + // The base implementation should *always* be provided. + void foo__base(int a, float b) { + [base, possibly slow implementation] + } + decltype(foo__base) foo__avx512; + decltype(foo__base) foo__avx2; + decltype(foo__base) foo__avx; + void foo(int a, float b) { + // You should always order things by their preference, faster + // implementations earlier in the function. + AVX512_DO(foo, a, b); + AVX2_DO(foo, a, b); + AVX_DO(foo, a, b); + BASE_DO(foo, a, b); + } + +*/ +// Details: this functionality basically covers the cases for both build time +// and run time architecture support. +// +// During build time: +// The build system should provide flags CAFFE2_PERF_WITH_AVX512, +// CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that corresponds to the +// __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__, and __AVX__ flags the +// compiler provides. Note that we do not use the compiler flags but rely on +// the build system flags, because the common files (like foo.cc above) will +// always be built without __AVX512F__, __AVX512DQ__, __AVX512VL__, __AVX2__ +// and __AVX__. +// During run time: +// we use cpuinfo to identify cpu support and run the proper functions. + +#pragma once +#if defined(CAFFE2_PERF_WITH_SVE) || defined(CAFFE2_PERF_WITH_AVX512) || \ + defined(CAFFE2_PERF_WITH_AVX2) || defined(CAFFE2_PERF_WITH_AVX) +#include +#endif + +// DO macros: these should be used in your entry function, similar to foo() +// above, that routes implementations based on CPU capability. + +#define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__); + +#ifdef CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_arm_sve(); \ + if (isDo) { \ + return funcname##__sve(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_SVE + +#ifdef CAFFE2_PERF_WITH_AVX512 +#define AVX512_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && \ + cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512dq() && \ + cpuinfo_has_x86_avx512vl(); \ + if (isDo) { \ + return funcname##__avx512(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_AVX512 +#define AVX512_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_AVX512 + +#ifdef CAFFE2_PERF_WITH_AVX2 +#define AVX2_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx2(); \ + if (isDo) { \ + return funcname##__avx2(__VA_ARGS__); \ + } \ + } +#define AVX2_FMA_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx2() && \ + cpuinfo_has_x86_fma3(); \ + if (isDo) { \ + return funcname##__avx2_fma(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_AVX2 +#define AVX2_DO(funcname, ...) +#define AVX2_FMA_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_AVX2 + +#ifdef CAFFE2_PERF_WITH_AVX +#define AVX_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx(); \ + if (isDo) { \ + return funcname##__avx(__VA_ARGS__); \ + } \ + } +#define AVX_F16C_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx() && \ + cpuinfo_has_x86_f16c(); \ + if (isDo) { \ + return funcname##__avx_f16c(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_AVX +#define AVX_DO(funcname, ...) +#define AVX_F16C_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_AVX + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/embedding_lookup_idx.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/embedding_lookup_idx.h new file mode 100644 index 0000000000000000000000000000000000000000..45eb7106de95e6ae73e4a99b020339aadb7fc527 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/perfkernels/embedding_lookup_idx.h @@ -0,0 +1,62 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace caffe2 { + +// clang-format off +/** + * Embedding lookup with reduction. + * + * `input` of size data_size * block_size + * `indices` of size index_size + * `offsets` of size output_size + * `weights` nullptr or array of size index_size + * `out` of size output_size * block_size + * + * Behavior is roughly equivalent to pseudocode: + * + * pos = 0 + * for (i = 0..output_size-1) + * for (k = 0..block_size-1) + * out[i*block_size + k] = 0 + * start_offset = offsets[i] + * end_offset = offsets[i+1] + * length = end_offset - start_offset + * for (j = start_offset..end_offset-1) + * for (k = 0..block_size-1) + * out[i*block_size + k] += input[indices[pos]*block_size + k] * + * (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0) + * pos += 1 + * if (normalize_weights && length > 0) + * for (k = 0..block_size-1) + * out[i*block_size + k] /= length + * + * TODO: make this API also take "offsets" rather than "lengths" to match the + * API for PyTorch's EmbeddingBag + */ +// clang-format on +template < + typename IndexType, + typename InType, + typename OutType, + bool IS_WEIGHT_POSITIONAL = false> +void EmbeddingLookupIdx( + const std::int64_t block_size, + const std::int64_t output_size, + const std::int64_t index_size, + const std::int64_t data_size, + const InType* input, + const IndexType* indices, + const IndexType* offsets, + const float* weights, // optional, can be null for non-weighted sum + const float* scale_bias, // optional scale & bias params for uint8 input + bool normalize_by_lengths, + OutType* out); + +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/crc_alt.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/crc_alt.h new file mode 100644 index 0000000000000000000000000000000000000000..5586b37e59707104b1138b4f806200ded8466e87 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/crc_alt.h @@ -0,0 +1,1348 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// ////////////////////////////////////////////////////////// +// Crc32.h +// Copyright (c) 2011-2019 Stephan Brumme. All rights reserved. +// Slicing-by-16 contributed by Bulat Ziganshin +// Tableless bytewise CRC contributed by Hagai Gold +// see http://create.stephan-brumme.com/disclaimer.html +// + +// if running on an embedded system, you might consider shrinking the +// big Crc32Lookup table by undefining these lines: +#define CRC32_USE_LOOKUP_TABLE_BYTE +#define CRC32_USE_LOOKUP_TABLE_SLICING_BY_4 +#define CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 +#define CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +// - crc32_bitwise doesn't need it at all +// - crc32_halfbyte has its own small lookup table +// - crc32_1byte_tableless and crc32_1byte_tableless2 don't need it at all +// - crc32_1byte needs only Crc32Lookup[0] +// - crc32_4bytes needs only Crc32Lookup[0..3] +// - crc32_8bytes needs only Crc32Lookup[0..7] +// - crc32_4x8bytes needs only Crc32Lookup[0..7] +// - crc32_16bytes needs all of Crc32Lookup +// using the aforementioned #defines the table is automatically fitted to your needs + +// uint8_t, uint32_t, int32_t +#include +// size_t +#include + +// crc32_fast selects the fastest algorithm depending on flags (CRC32_USE_LOOKUP_...) +/// compute CRC32 using the fastest algorithm for large datasets on modern CPUs +uint32_t crc32_fast (const void* data, size_t length, uint32_t previousCrc32 = 0); + +/// merge two CRC32 such that result = crc32(dataB, lengthB, crc32(dataA, lengthA)) +uint32_t crc32_combine (uint32_t crcA, uint32_t crcB, size_t lengthB); + +/// compute CRC32 (bitwise algorithm) +uint32_t crc32_bitwise (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (half-byte algorithm) +uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32 = 0); + +#ifdef CRC32_USE_LOOKUP_TABLE_BYTE +/// compute CRC32 (standard algorithm) +uint32_t crc32_1byte (const void* data, size_t length, uint32_t previousCrc32 = 0); +#endif + +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless2(const void* data, size_t length, uint32_t previousCrc32 = 0); + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_4 +/// compute CRC32 (Slicing-by-4 algorithm) +uint32_t crc32_4bytes (const void* data, size_t length, uint32_t previousCrc32 = 0); +#endif + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 +/// compute CRC32 (Slicing-by-8 algorithm) +uint32_t crc32_8bytes (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (Slicing-by-8 algorithm), unroll inner loop 4 times +uint32_t crc32_4x8bytes(const void* data, size_t length, uint32_t previousCrc32 = 0); +#endif + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +/// compute CRC32 (Slicing-by-16 algorithm) +uint32_t crc32_16bytes (const void* data, size_t length, uint32_t previousCrc32 = 0); +/// compute CRC32 (Slicing-by-16 algorithm, prefetch upcoming data blocks) +uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previousCrc32 = 0, size_t prefetchAhead = 256); +#endif + +// ////////////////////////////////////////////////////////// +// Crc32.cpp +// Copyright (c) 2011-2019 Stephan Brumme. All rights reserved. +// Slicing-by-16 contributed by Bulat Ziganshin +// Tableless bytewise CRC contributed by Hagai Gold +// see http://create.stephan-brumme.com/disclaimer.html +// + +// if running on an embedded system, you might consider shrinking the +// big Crc32Lookup table: +// - crc32_bitwise doesn't need it at all +// - crc32_halfbyte has its own small lookup table +// - crc32_1byte needs only Crc32Lookup[0] +// - crc32_4bytes needs only Crc32Lookup[0..3] +// - crc32_8bytes needs only Crc32Lookup[0..7] +// - crc32_4x8bytes needs only Crc32Lookup[0..7] +// - crc32_16bytes needs all of Crc32Lookup + + +#ifndef __LITTLE_ENDIAN + #define __LITTLE_ENDIAN 1234 +#endif +#ifndef __BIG_ENDIAN + #define __BIG_ENDIAN 4321 +#endif + +// define endianness and some integer data types +#if defined(_MSC_VER) || defined(__MINGW32__) + // Windows always little endian + #define __BYTE_ORDER __LITTLE_ENDIAN + + // intrinsics / prefetching + #if defined(_M_ARM64) + #include + #else + #include + #endif + + #ifdef __MINGW32__ + #define PREFETCH(location) __builtin_prefetch(location) + #else + #if defined(_M_ARM64) + #define PREFETCH(location) __prefetch(location) + #else + #define PREFETCH(location) _mm_prefetch(location, _MM_HINT_T0) + #endif + #endif +#elif defined(__APPLE__) + #include + #if TARGET_IPHONE_SIMULATOR + #define __BYTE_ORDER __LITTLE_ENDIAN + #elif TARGET_OS_IPHONE + #define __BYTE_ORDER __LITTLE_ENDIAN + #elif TARGET_OS_MAC + #include + #if defined(__BIG_ENDIAN__) + #define __BYTE_ORDER __BIG_ENDIAN + #endif + #if defined(__LITTLE_ENDIAN__) + #define __BYTE_ORDER __LITTLE_ENDIAN + #endif + #else + # error "Unknown Apple platform" + #endif +#elif defined(__ARMEB__) + #define __BYTE_ORDER __BIG_ENDIAN +#elif (defined(__BYTE_ORDER__) and !defined(__BYTE_ORDER)) + #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + #define __BYTE_ORDER __BIG_ENDIAN + #else + #define __BYTE_ORDER __LITTLE_ENDIAN + #endif +#else + // defines __BYTE_ORDER as __LITTLE_ENDIAN or __BIG_ENDIAN + #include +#endif + +// intrinsics / prefetching +#ifdef __GNUC__ + #define PREFETCH(location) __builtin_prefetch(location) +#else +#ifndef PREFETCH + // no prefetching + #define PREFETCH(location) ; +#endif +#endif + +// abort if byte order is undefined +#ifndef __BYTE_ORDER +#error undefined byte order, compile with -D__BYTE_ORDER=1234 (if little endian) or -D__BYTE_ORDER=4321 (big endian) +#endif + + +namespace +{ + /// zlib's CRC32 polynomial + const uint32_t Polynomial = 0xEDB88320; + + /// swap endianness + static inline uint32_t swap(uint32_t x) + { + #if defined(__GNUC__) || defined(__clang__) + return __builtin_bswap32(x); + #else + return (x >> 24) | + ((x >> 8) & 0x0000FF00) | + ((x << 8) & 0x00FF0000) | + (x << 24); + #endif + } + + /// Slicing-By-16 + #ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 + const size_t MaxSlice = 16; + #elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) + const size_t MaxSlice = 8; + #elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) + const size_t MaxSlice = 4; + #elif defined(CRC32_USE_LOOKUP_TABLE_BYTE) + const size_t MaxSlice = 1; + #else + #define NO_LUT // don't need Crc32Lookup at all + #endif + +} // anonymous namespace + +#ifndef NO_LUT +/// forward declaration, table is at the end of this file +extern const uint32_t Crc32Lookup[MaxSlice][256]; // extern is needed to keep compiler happy +#endif + + +/// compute CRC32 (bitwise algorithm) +uint32_t crc32_bitwise(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + { + crc ^= *current++; + + for (int j = 0; j < 8; j++) + { + // branch-free + crc = (crc >> 1) ^ (-int32_t(crc & 1) & Polynomial); + + // branching, much slower: + //if (crc & 1) + // crc = (crc >> 1) ^ Polynomial; + //else + // crc = crc >> 1; + } + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (half-byte algorithm) +uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + /// look-up table for half-byte, same as crc32Lookup[0][16*i] + static const uint32_t Crc32Lookup16[16] = + { + 0x00000000,0x1DB71064,0x3B6E20C8,0x26D930AC,0x76DC4190,0x6B6B51F4,0x4DB26158,0x5005713C, + 0xEDB88320,0xF00F9344,0xD6D6A3E8,0xCB61B38C,0x9B64C2B0,0x86D3D2D4,0xA00AE278,0xBDBDF21C + }; + + while (length-- != 0) + { + crc = Crc32Lookup16[(crc ^ *current ) & 0x0F] ^ (crc >> 4); + crc = Crc32Lookup16[(crc ^ (*current >> 4)) & 0x0F] ^ (crc >> 4); + current++; + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +#ifdef CRC32_USE_LOOKUP_TABLE_BYTE +/// compute CRC32 (standard algorithm) +uint32_t crc32_1byte(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *current++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif + + +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + { + uint8_t s = uint8_t(crc) ^ *current++; + + // Hagai Gold made me aware of this table-less algorithm and send me code + + // polynomial 0xEDB88320 can be written in binary as 11101101101110001000001100100000b + // reverse the bits (or just assume bit 0 is the first one) + // and we have bits set at position 0, 1, 2, 4, 5, 7, 8, 10, 11, 12, 16, 22, 23, 26 + // => those are the shift offsets: + //crc = (crc >> 8) ^ + // t ^ + // (t >> 1) ^ (t >> 2) ^ (t >> 4) ^ (t >> 5) ^ // == y + // (t >> 7) ^ (t >> 8) ^ (t >> 10) ^ (t >> 11) ^ // == y >> 6 + // (t >> 12) ^ (t >> 16) ^ // == z + // (t >> 22) ^ (t >> 26) ^ // == z >> 10 + // (t >> 23); + + // the fastest I can come up with: + uint32_t low = (s ^ (s << 6)) & 0xFF; + uint32_t a = (low * ((1 << 23) + (1 << 14) + (1 << 2))); + crc = (crc >> 8) ^ + (low * ((1 << 24) + (1 << 16) + (1 << 8))) ^ + a ^ + (a >> 1) ^ + (low * ((1 << 20) + (1 << 12) )) ^ + (low << 19) ^ + (low << 17) ^ + (low >> 2); + + // Hagai's code: + /*uint32_t t = (s ^ (s << 6)) << 24; + // some temporaries to optimize XOR + uint32_t x = (t >> 1) ^ (t >> 2); + uint32_t y = x ^ (x >> 3); + uint32_t z = (t >> 12) ^ (t >> 16); + crc = (crc >> 8) ^ + t ^ (t >> 23) ^ + y ^ (y >> 6) ^ + z ^ (z >> 10);*/ + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (byte algorithm) without lookup tables +uint32_t crc32_1byte_tableless2(const void* data, size_t length, uint32_t previousCrc32) +{ + int32_t crc = ~previousCrc32; // note: signed integer, right shift distributes sign bit into lower bits + const uint8_t* current = (const uint8_t*) data; + + while (length-- != 0) + { + crc = crc ^ *current++; + + uint32_t c = (((crc << 31) >> 31) & ((Polynomial >> 7) ^ (Polynomial >> 1))) ^ + (((crc << 30) >> 31) & ((Polynomial >> 6) ^ Polynomial)) ^ + (((crc << 29) >> 31) & (Polynomial >> 5)) ^ + (((crc << 28) >> 31) & (Polynomial >> 4)) ^ + (((crc << 27) >> 31) & (Polynomial >> 3)) ^ + (((crc << 26) >> 31) & (Polynomial >> 2)) ^ + (((crc << 25) >> 31) & (Polynomial >> 1)) ^ + (((crc << 24) >> 31) & Polynomial); + + crc = ((uint32_t)crc >> 8) ^ c; // convert to unsigned integer before right shift + } + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_4 +/// compute CRC32 (Slicing-by-4 algorithm) +uint32_t crc32_4bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // process four bytes at once (Slicing-by-4) + while (length >= 4) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + crc = Crc32Lookup[0][ one & 0xFF] ^ + Crc32Lookup[1][(one>> 8) & 0xFF] ^ + Crc32Lookup[2][(one>>16) & 0xFF] ^ + Crc32Lookup[3][(one>>24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + crc = Crc32Lookup[0][(one>>24) & 0xFF] ^ + Crc32Lookup[1][(one>>16) & 0xFF] ^ + Crc32Lookup[2][(one>> 8) & 0xFF] ^ + Crc32Lookup[3][ one & 0xFF]; +#endif + + length -= 4; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 3 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif + + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 +/// compute CRC32 (Slicing-by-8 algorithm) +uint32_t crc32_8bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // process eight bytes at once (Slicing-by-8) + while (length >= 8) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + crc = Crc32Lookup[0][ two & 0xFF] ^ + Crc32Lookup[1][(two>> 8) & 0xFF] ^ + Crc32Lookup[2][(two>>16) & 0xFF] ^ + Crc32Lookup[3][(two>>24) & 0xFF] ^ + Crc32Lookup[4][ one & 0xFF] ^ + Crc32Lookup[5][(one>> 8) & 0xFF] ^ + Crc32Lookup[6][(one>>16) & 0xFF] ^ + Crc32Lookup[7][(one>>24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; + crc = Crc32Lookup[0][(two>>24) & 0xFF] ^ + Crc32Lookup[1][(two>>16) & 0xFF] ^ + Crc32Lookup[2][(two>> 8) & 0xFF] ^ + Crc32Lookup[3][ two & 0xFF] ^ + Crc32Lookup[4][(one>>24) & 0xFF] ^ + Crc32Lookup[5][(one>>16) & 0xFF] ^ + Crc32Lookup[6][(one>> 8) & 0xFF] ^ + Crc32Lookup[7][ one & 0xFF]; +#endif + + length -= 8; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 7 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (Slicing-by-8 algorithm), unroll inner loop 4 times +uint32_t crc32_4x8bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // enabling optimization (at least -O2) automatically unrolls the inner for-loop + const size_t Unroll = 4; + const size_t BytesAtOnce = 8 * Unroll; + + // process 4x eight bytes at once (Slicing-by-8) + while (length >= BytesAtOnce) + { + for (size_t unrolling = 0; unrolling < Unroll; unrolling++) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + crc = Crc32Lookup[0][ two & 0xFF] ^ + Crc32Lookup[1][(two>> 8) & 0xFF] ^ + Crc32Lookup[2][(two>>16) & 0xFF] ^ + Crc32Lookup[3][(two>>24) & 0xFF] ^ + Crc32Lookup[4][ one & 0xFF] ^ + Crc32Lookup[5][(one>> 8) & 0xFF] ^ + Crc32Lookup[6][(one>>16) & 0xFF] ^ + Crc32Lookup[7][(one>>24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; + crc = Crc32Lookup[0][(two>>24) & 0xFF] ^ + Crc32Lookup[1][(two>>16) & 0xFF] ^ + Crc32Lookup[2][(two>> 8) & 0xFF] ^ + Crc32Lookup[3][ two & 0xFF] ^ + Crc32Lookup[4][(one>>24) & 0xFF] ^ + Crc32Lookup[5][(one>>16) & 0xFF] ^ + Crc32Lookup[6][(one>> 8) & 0xFF] ^ + Crc32Lookup[7][ one & 0xFF]; +#endif + + } + + length -= BytesAtOnce; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 31 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif // CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 + + +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +/// compute CRC32 (Slicing-by-16 algorithm) +uint32_t crc32_16bytes(const void* data, size_t length, uint32_t previousCrc32) +{ + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // enabling optimization (at least -O2) automatically unrolls the inner for-loop + const size_t Unroll = 4; + const size_t BytesAtOnce = 16 * Unroll; + + while (length >= BytesAtOnce) + { + for (size_t unrolling = 0; unrolling < Unroll; unrolling++) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][ four & 0xFF] ^ + Crc32Lookup[ 1][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 3][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 4][ three & 0xFF] ^ + Crc32Lookup[ 5][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 7][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 8][ two & 0xFF] ^ + Crc32Lookup[ 9][(two >> 8) & 0xFF] ^ + Crc32Lookup[10][(two >> 16) & 0xFF] ^ + Crc32Lookup[11][(two >> 24) & 0xFF] ^ + Crc32Lookup[12][ one & 0xFF] ^ + Crc32Lookup[13][(one >> 8) & 0xFF] ^ + Crc32Lookup[14][(one >> 16) & 0xFF] ^ + Crc32Lookup[15][(one >> 24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; + uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 1][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 3][ four & 0xFF] ^ + Crc32Lookup[ 4][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 5][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 7][ three & 0xFF] ^ + Crc32Lookup[ 8][(two >> 24) & 0xFF] ^ + Crc32Lookup[ 9][(two >> 16) & 0xFF] ^ + Crc32Lookup[10][(two >> 8) & 0xFF] ^ + Crc32Lookup[11][ two & 0xFF] ^ + Crc32Lookup[12][(one >> 24) & 0xFF] ^ + Crc32Lookup[13][(one >> 16) & 0xFF] ^ + Crc32Lookup[14][(one >> 8) & 0xFF] ^ + Crc32Lookup[15][ one & 0xFF]; +#endif + } + + length -= BytesAtOnce; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 63 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} + + +/// compute CRC32 (Slicing-by-16 algorithm, prefetch upcoming data blocks) +uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previousCrc32, size_t prefetchAhead) +{ + // CRC code is identical to crc32_16bytes (including unrolling), only added prefetching + // 256 bytes look-ahead seems to be the sweet spot on Core i7 CPUs + + uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF + const uint32_t* current = (const uint32_t*) data; + + // enabling optimization (at least -O2) automatically unrolls the for-loop + const size_t Unroll = 4; + const size_t BytesAtOnce = 16 * Unroll; + + while (length >= BytesAtOnce + prefetchAhead) + { + PREFETCH(((const char*) current) + prefetchAhead); + + for (size_t unrolling = 0; unrolling < Unroll; unrolling++) + { +#if __BYTE_ORDER == __BIG_ENDIAN + uint32_t one = *current++ ^ swap(crc); + uint32_t two = *current++; + uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][ four & 0xFF] ^ + Crc32Lookup[ 1][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 3][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 4][ three & 0xFF] ^ + Crc32Lookup[ 5][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 7][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 8][ two & 0xFF] ^ + Crc32Lookup[ 9][(two >> 8) & 0xFF] ^ + Crc32Lookup[10][(two >> 16) & 0xFF] ^ + Crc32Lookup[11][(two >> 24) & 0xFF] ^ + Crc32Lookup[12][ one & 0xFF] ^ + Crc32Lookup[13][(one >> 8) & 0xFF] ^ + Crc32Lookup[14][(one >> 16) & 0xFF] ^ + Crc32Lookup[15][(one >> 24) & 0xFF]; +#else + uint32_t one = *current++ ^ crc; + uint32_t two = *current++; + uint32_t three = *current++; + uint32_t four = *current++; + crc = Crc32Lookup[ 0][(four >> 24) & 0xFF] ^ + Crc32Lookup[ 1][(four >> 16) & 0xFF] ^ + Crc32Lookup[ 2][(four >> 8) & 0xFF] ^ + Crc32Lookup[ 3][ four & 0xFF] ^ + Crc32Lookup[ 4][(three >> 24) & 0xFF] ^ + Crc32Lookup[ 5][(three >> 16) & 0xFF] ^ + Crc32Lookup[ 6][(three >> 8) & 0xFF] ^ + Crc32Lookup[ 7][ three & 0xFF] ^ + Crc32Lookup[ 8][(two >> 24) & 0xFF] ^ + Crc32Lookup[ 9][(two >> 16) & 0xFF] ^ + Crc32Lookup[10][(two >> 8) & 0xFF] ^ + Crc32Lookup[11][ two & 0xFF] ^ + Crc32Lookup[12][(one >> 24) & 0xFF] ^ + Crc32Lookup[13][(one >> 16) & 0xFF] ^ + Crc32Lookup[14][(one >> 8) & 0xFF] ^ + Crc32Lookup[15][ one & 0xFF]; +#endif + } + + length -= BytesAtOnce; + } + + const uint8_t* currentChar = (const uint8_t*) current; + // remaining 1 to 63 bytes (standard algorithm) + while (length-- != 0) + crc = (crc >> 8) ^ Crc32Lookup[0][(crc & 0xFF) ^ *currentChar++]; + + return ~crc; // same as crc ^ 0xFFFFFFFF +} +#endif + + +/// compute CRC32 using the fastest algorithm for large datasets on modern CPUs +uint32_t crc32_fast(const void* data, size_t length, uint32_t previousCrc32) +{ +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 + return crc32_16bytes (data, length, previousCrc32); +#elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) + return crc32_8bytes (data, length, previousCrc32); +#elif defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) + return crc32_4bytes (data, length, previousCrc32); +#elif defined(CRC32_USE_LOOKUP_TABLE_BYTE) + return crc32_1byte (data, length, previousCrc32); +#else + return crc32_halfbyte(data, length, previousCrc32); +#endif +} + + +/// merge two CRC32 such that result = crc32(dataB, lengthB, crc32(dataA, lengthA)) +uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) +{ + // based on Mark Adler's crc_combine from + // https://github.com/madler/pigz/blob/master/pigz.c + + // main idea: + // - if you have two equally-sized blocks A and B, + // then you can create a block C = A ^ B + // which has the property crc(C) = crc(A) ^ crc(B) + // - if you append length(B) zeros to A and call it A' (think of it as AAAA000) + // and prepend length(A) zeros to B and call it B' (think of it as 0000BBB) + // then exists a C' = A' ^ B' + // - remember: if you XOR something with zero, it remains unchanged: X ^ 0 = X + // - that means C' = A concat B so that crc(A concat B) = crc(C') = crc(A') ^ crc(B') + // - the trick is to compute crc(A') based on crc(A) + // and crc(B') based on crc(B) + // - since B' starts with many zeros, the crc of those initial zeros is still zero + // - that means crc(B') = crc(B) + // - unfortunately the trailing zeros of A' change the crc, so usually crc(A') != crc(A) + // - the following code is a fast algorithm to compute crc(A') + // - starting with crc(A) and appending length(B) zeros, needing just log2(length(B)) iterations + // - the details are explained by the original author at + // https://stackoverflow.com/questions/23122312/crc-calculation-of-a-mostly-static-data-stream/23126768 + // + // notes: + // - I squeezed everything into one function to keep global namespace clean (original code two helper functions) + // - most original comments are still in place, I added comments where these helper functions where made inline code + // - performance-wise there isn't any differenze to the original zlib/pigz code + + // degenerated case + if (lengthB == 0) + return crcA; + + /// CRC32 => 32 bits + const uint32_t CrcBits = 32; + + uint32_t odd [CrcBits]; // odd-power-of-two zeros operator + uint32_t even[CrcBits]; // even-power-of-two zeros operator + + // put operator for one zero bit in odd + odd[0] = Polynomial; // CRC-32 polynomial + for (uint32_t i = 1; i < CrcBits; i++) + odd[i] = 1 << (i - 1); + + // put operator for two zero bits in even + // same as gf2_matrix_square(even, odd); + for (uint32_t i = 0; i < CrcBits; i++) + { + uint32_t vec = odd[i]; + even[i] = 0; + for (int j = 0; vec != 0; j++, vec >>= 1) + if (vec & 1) + even[i] ^= odd[j]; + } + // put operator for four zero bits in odd + // same as gf2_matrix_square(odd, even); + for (uint32_t i = 0; i < CrcBits; i++) + { + uint32_t vec = even[i]; + odd[i] = 0; + for (int j = 0; vec != 0; j++, vec >>= 1) + if (vec & 1) + odd[i] ^= even[j]; + } + + // the following loop becomes much shorter if I keep swapping even and odd + uint32_t* a = even; + uint32_t* b = odd; + // apply secondLength zeros to firstCrc32 + for (; lengthB > 0; lengthB >>= 1) + { + // same as gf2_matrix_square(a, b); + for (uint32_t i = 0; i < CrcBits; i++) + { + uint32_t vec = b[i]; + a[i] = 0; + for (int j = 0; vec != 0; j++, vec >>= 1) + if (vec & 1) + a[i] ^= b[j]; + } + + // apply zeros operator for this bit + if (lengthB & 1) + { + // same as firstCrc32 = gf2_matrix_times(a, firstCrc32); + uint32_t sum = 0; + for (int i = 0; crcA != 0; i++, crcA >>= 1) + if (crcA & 1) + sum ^= a[i]; + crcA = sum; + } + + // switch even and odd + uint32_t* t = a; a = b; b = t; + } + + // return combined crc + return crcA ^ crcB; +} + + +// ////////////////////////////////////////////////////////// +// constants + + +#ifndef NO_LUT +/// look-up table, already declared above +const uint32_t Crc32Lookup[MaxSlice][256] = +{ + //// same algorithm as crc32_bitwise + //for (int i = 0; i <= 0xFF; i++) + //{ + // uint32_t crc = i; + // for (int j = 0; j < 8; j++) + // crc = (crc >> 1) ^ ((crc & 1) * Polynomial); + // Crc32Lookup[0][i] = crc; + //} + //// ... and the following slicing-by-8 algorithm (from Intel): + //// http://www.intel.com/technology/comms/perfnet/download/CRC_generators.pdf + //// http://sourceforge.net/projects/slicing-by-8/ + //for (int slice = 1; slice < MaxSlice; slice++) + // Crc32Lookup[slice][i] = (Crc32Lookup[slice - 1][i] >> 8) ^ Crc32Lookup[0][Crc32Lookup[slice - 1][i] & 0xFF]; + { + // note: the first number of every second row corresponds to the half-byte look-up table ! + 0x00000000,0x77073096,0xEE0E612C,0x990951BA,0x076DC419,0x706AF48F,0xE963A535,0x9E6495A3, + 0x0EDB8832,0x79DCB8A4,0xE0D5E91E,0x97D2D988,0x09B64C2B,0x7EB17CBD,0xE7B82D07,0x90BF1D91, + 0x1DB71064,0x6AB020F2,0xF3B97148,0x84BE41DE,0x1ADAD47D,0x6DDDE4EB,0xF4D4B551,0x83D385C7, + 0x136C9856,0x646BA8C0,0xFD62F97A,0x8A65C9EC,0x14015C4F,0x63066CD9,0xFA0F3D63,0x8D080DF5, + 0x3B6E20C8,0x4C69105E,0xD56041E4,0xA2677172,0x3C03E4D1,0x4B04D447,0xD20D85FD,0xA50AB56B, + 0x35B5A8FA,0x42B2986C,0xDBBBC9D6,0xACBCF940,0x32D86CE3,0x45DF5C75,0xDCD60DCF,0xABD13D59, + 0x26D930AC,0x51DE003A,0xC8D75180,0xBFD06116,0x21B4F4B5,0x56B3C423,0xCFBA9599,0xB8BDA50F, + 0x2802B89E,0x5F058808,0xC60CD9B2,0xB10BE924,0x2F6F7C87,0x58684C11,0xC1611DAB,0xB6662D3D, + 0x76DC4190,0x01DB7106,0x98D220BC,0xEFD5102A,0x71B18589,0x06B6B51F,0x9FBFE4A5,0xE8B8D433, + 0x7807C9A2,0x0F00F934,0x9609A88E,0xE10E9818,0x7F6A0DBB,0x086D3D2D,0x91646C97,0xE6635C01, + 0x6B6B51F4,0x1C6C6162,0x856530D8,0xF262004E,0x6C0695ED,0x1B01A57B,0x8208F4C1,0xF50FC457, + 0x65B0D9C6,0x12B7E950,0x8BBEB8EA,0xFCB9887C,0x62DD1DDF,0x15DA2D49,0x8CD37CF3,0xFBD44C65, + 0x4DB26158,0x3AB551CE,0xA3BC0074,0xD4BB30E2,0x4ADFA541,0x3DD895D7,0xA4D1C46D,0xD3D6F4FB, + 0x4369E96A,0x346ED9FC,0xAD678846,0xDA60B8D0,0x44042D73,0x33031DE5,0xAA0A4C5F,0xDD0D7CC9, + 0x5005713C,0x270241AA,0xBE0B1010,0xC90C2086,0x5768B525,0x206F85B3,0xB966D409,0xCE61E49F, + 0x5EDEF90E,0x29D9C998,0xB0D09822,0xC7D7A8B4,0x59B33D17,0x2EB40D81,0xB7BD5C3B,0xC0BA6CAD, + 0xEDB88320,0x9ABFB3B6,0x03B6E20C,0x74B1D29A,0xEAD54739,0x9DD277AF,0x04DB2615,0x73DC1683, + 0xE3630B12,0x94643B84,0x0D6D6A3E,0x7A6A5AA8,0xE40ECF0B,0x9309FF9D,0x0A00AE27,0x7D079EB1, + 0xF00F9344,0x8708A3D2,0x1E01F268,0x6906C2FE,0xF762575D,0x806567CB,0x196C3671,0x6E6B06E7, + 0xFED41B76,0x89D32BE0,0x10DA7A5A,0x67DD4ACC,0xF9B9DF6F,0x8EBEEFF9,0x17B7BE43,0x60B08ED5, + 0xD6D6A3E8,0xA1D1937E,0x38D8C2C4,0x4FDFF252,0xD1BB67F1,0xA6BC5767,0x3FB506DD,0x48B2364B, + 0xD80D2BDA,0xAF0A1B4C,0x36034AF6,0x41047A60,0xDF60EFC3,0xA867DF55,0x316E8EEF,0x4669BE79, + 0xCB61B38C,0xBC66831A,0x256FD2A0,0x5268E236,0xCC0C7795,0xBB0B4703,0x220216B9,0x5505262F, + 0xC5BA3BBE,0xB2BD0B28,0x2BB45A92,0x5CB36A04,0xC2D7FFA7,0xB5D0CF31,0x2CD99E8B,0x5BDEAE1D, + 0x9B64C2B0,0xEC63F226,0x756AA39C,0x026D930A,0x9C0906A9,0xEB0E363F,0x72076785,0x05005713, + 0x95BF4A82,0xE2B87A14,0x7BB12BAE,0x0CB61B38,0x92D28E9B,0xE5D5BE0D,0x7CDCEFB7,0x0BDBDF21, + 0x86D3D2D4,0xF1D4E242,0x68DDB3F8,0x1FDA836E,0x81BE16CD,0xF6B9265B,0x6FB077E1,0x18B74777, + 0x88085AE6,0xFF0F6A70,0x66063BCA,0x11010B5C,0x8F659EFF,0xF862AE69,0x616BFFD3,0x166CCF45, + 0xA00AE278,0xD70DD2EE,0x4E048354,0x3903B3C2,0xA7672661,0xD06016F7,0x4969474D,0x3E6E77DB, + 0xAED16A4A,0xD9D65ADC,0x40DF0B66,0x37D83BF0,0xA9BCAE53,0xDEBB9EC5,0x47B2CF7F,0x30B5FFE9, + 0xBDBDF21C,0xCABAC28A,0x53B39330,0x24B4A3A6,0xBAD03605,0xCDD70693,0x54DE5729,0x23D967BF, + 0xB3667A2E,0xC4614AB8,0x5D681B02,0x2A6F2B94,0xB40BBE37,0xC30C8EA1,0x5A05DF1B,0x2D02EF8D, + } + +#if defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_16) + // beyond this point only relevant for Slicing-by-4, Slicing-by-8 and Slicing-by-16 + ,{ + 0x00000000,0x191B3141,0x32366282,0x2B2D53C3,0x646CC504,0x7D77F445,0x565AA786,0x4F4196C7, + 0xC8D98A08,0xD1C2BB49,0xFAEFE88A,0xE3F4D9CB,0xACB54F0C,0xB5AE7E4D,0x9E832D8E,0x87981CCF, + 0x4AC21251,0x53D92310,0x78F470D3,0x61EF4192,0x2EAED755,0x37B5E614,0x1C98B5D7,0x05838496, + 0x821B9859,0x9B00A918,0xB02DFADB,0xA936CB9A,0xE6775D5D,0xFF6C6C1C,0xD4413FDF,0xCD5A0E9E, + 0x958424A2,0x8C9F15E3,0xA7B24620,0xBEA97761,0xF1E8E1A6,0xE8F3D0E7,0xC3DE8324,0xDAC5B265, + 0x5D5DAEAA,0x44469FEB,0x6F6BCC28,0x7670FD69,0x39316BAE,0x202A5AEF,0x0B07092C,0x121C386D, + 0xDF4636F3,0xC65D07B2,0xED705471,0xF46B6530,0xBB2AF3F7,0xA231C2B6,0x891C9175,0x9007A034, + 0x179FBCFB,0x0E848DBA,0x25A9DE79,0x3CB2EF38,0x73F379FF,0x6AE848BE,0x41C51B7D,0x58DE2A3C, + 0xF0794F05,0xE9627E44,0xC24F2D87,0xDB541CC6,0x94158A01,0x8D0EBB40,0xA623E883,0xBF38D9C2, + 0x38A0C50D,0x21BBF44C,0x0A96A78F,0x138D96CE,0x5CCC0009,0x45D73148,0x6EFA628B,0x77E153CA, + 0xBABB5D54,0xA3A06C15,0x888D3FD6,0x91960E97,0xDED79850,0xC7CCA911,0xECE1FAD2,0xF5FACB93, + 0x7262D75C,0x6B79E61D,0x4054B5DE,0x594F849F,0x160E1258,0x0F152319,0x243870DA,0x3D23419B, + 0x65FD6BA7,0x7CE65AE6,0x57CB0925,0x4ED03864,0x0191AEA3,0x188A9FE2,0x33A7CC21,0x2ABCFD60, + 0xAD24E1AF,0xB43FD0EE,0x9F12832D,0x8609B26C,0xC94824AB,0xD05315EA,0xFB7E4629,0xE2657768, + 0x2F3F79F6,0x362448B7,0x1D091B74,0x04122A35,0x4B53BCF2,0x52488DB3,0x7965DE70,0x607EEF31, + 0xE7E6F3FE,0xFEFDC2BF,0xD5D0917C,0xCCCBA03D,0x838A36FA,0x9A9107BB,0xB1BC5478,0xA8A76539, + 0x3B83984B,0x2298A90A,0x09B5FAC9,0x10AECB88,0x5FEF5D4F,0x46F46C0E,0x6DD93FCD,0x74C20E8C, + 0xF35A1243,0xEA412302,0xC16C70C1,0xD8774180,0x9736D747,0x8E2DE606,0xA500B5C5,0xBC1B8484, + 0x71418A1A,0x685ABB5B,0x4377E898,0x5A6CD9D9,0x152D4F1E,0x0C367E5F,0x271B2D9C,0x3E001CDD, + 0xB9980012,0xA0833153,0x8BAE6290,0x92B553D1,0xDDF4C516,0xC4EFF457,0xEFC2A794,0xF6D996D5, + 0xAE07BCE9,0xB71C8DA8,0x9C31DE6B,0x852AEF2A,0xCA6B79ED,0xD37048AC,0xF85D1B6F,0xE1462A2E, + 0x66DE36E1,0x7FC507A0,0x54E85463,0x4DF36522,0x02B2F3E5,0x1BA9C2A4,0x30849167,0x299FA026, + 0xE4C5AEB8,0xFDDE9FF9,0xD6F3CC3A,0xCFE8FD7B,0x80A96BBC,0x99B25AFD,0xB29F093E,0xAB84387F, + 0x2C1C24B0,0x350715F1,0x1E2A4632,0x07317773,0x4870E1B4,0x516BD0F5,0x7A468336,0x635DB277, + 0xCBFAD74E,0xD2E1E60F,0xF9CCB5CC,0xE0D7848D,0xAF96124A,0xB68D230B,0x9DA070C8,0x84BB4189, + 0x03235D46,0x1A386C07,0x31153FC4,0x280E0E85,0x674F9842,0x7E54A903,0x5579FAC0,0x4C62CB81, + 0x8138C51F,0x9823F45E,0xB30EA79D,0xAA1596DC,0xE554001B,0xFC4F315A,0xD7626299,0xCE7953D8, + 0x49E14F17,0x50FA7E56,0x7BD72D95,0x62CC1CD4,0x2D8D8A13,0x3496BB52,0x1FBBE891,0x06A0D9D0, + 0x5E7EF3EC,0x4765C2AD,0x6C48916E,0x7553A02F,0x3A1236E8,0x230907A9,0x0824546A,0x113F652B, + 0x96A779E4,0x8FBC48A5,0xA4911B66,0xBD8A2A27,0xF2CBBCE0,0xEBD08DA1,0xC0FDDE62,0xD9E6EF23, + 0x14BCE1BD,0x0DA7D0FC,0x268A833F,0x3F91B27E,0x70D024B9,0x69CB15F8,0x42E6463B,0x5BFD777A, + 0xDC656BB5,0xC57E5AF4,0xEE530937,0xF7483876,0xB809AEB1,0xA1129FF0,0x8A3FCC33,0x9324FD72, + }, + + { + 0x00000000,0x01C26A37,0x0384D46E,0x0246BE59,0x0709A8DC,0x06CBC2EB,0x048D7CB2,0x054F1685, + 0x0E1351B8,0x0FD13B8F,0x0D9785D6,0x0C55EFE1,0x091AF964,0x08D89353,0x0A9E2D0A,0x0B5C473D, + 0x1C26A370,0x1DE4C947,0x1FA2771E,0x1E601D29,0x1B2F0BAC,0x1AED619B,0x18ABDFC2,0x1969B5F5, + 0x1235F2C8,0x13F798FF,0x11B126A6,0x10734C91,0x153C5A14,0x14FE3023,0x16B88E7A,0x177AE44D, + 0x384D46E0,0x398F2CD7,0x3BC9928E,0x3A0BF8B9,0x3F44EE3C,0x3E86840B,0x3CC03A52,0x3D025065, + 0x365E1758,0x379C7D6F,0x35DAC336,0x3418A901,0x3157BF84,0x3095D5B3,0x32D36BEA,0x331101DD, + 0x246BE590,0x25A98FA7,0x27EF31FE,0x262D5BC9,0x23624D4C,0x22A0277B,0x20E69922,0x2124F315, + 0x2A78B428,0x2BBADE1F,0x29FC6046,0x283E0A71,0x2D711CF4,0x2CB376C3,0x2EF5C89A,0x2F37A2AD, + 0x709A8DC0,0x7158E7F7,0x731E59AE,0x72DC3399,0x7793251C,0x76514F2B,0x7417F172,0x75D59B45, + 0x7E89DC78,0x7F4BB64F,0x7D0D0816,0x7CCF6221,0x798074A4,0x78421E93,0x7A04A0CA,0x7BC6CAFD, + 0x6CBC2EB0,0x6D7E4487,0x6F38FADE,0x6EFA90E9,0x6BB5866C,0x6A77EC5B,0x68315202,0x69F33835, + 0x62AF7F08,0x636D153F,0x612BAB66,0x60E9C151,0x65A6D7D4,0x6464BDE3,0x662203BA,0x67E0698D, + 0x48D7CB20,0x4915A117,0x4B531F4E,0x4A917579,0x4FDE63FC,0x4E1C09CB,0x4C5AB792,0x4D98DDA5, + 0x46C49A98,0x4706F0AF,0x45404EF6,0x448224C1,0x41CD3244,0x400F5873,0x4249E62A,0x438B8C1D, + 0x54F16850,0x55330267,0x5775BC3E,0x56B7D609,0x53F8C08C,0x523AAABB,0x507C14E2,0x51BE7ED5, + 0x5AE239E8,0x5B2053DF,0x5966ED86,0x58A487B1,0x5DEB9134,0x5C29FB03,0x5E6F455A,0x5FAD2F6D, + 0xE1351B80,0xE0F771B7,0xE2B1CFEE,0xE373A5D9,0xE63CB35C,0xE7FED96B,0xE5B86732,0xE47A0D05, + 0xEF264A38,0xEEE4200F,0xECA29E56,0xED60F461,0xE82FE2E4,0xE9ED88D3,0xEBAB368A,0xEA695CBD, + 0xFD13B8F0,0xFCD1D2C7,0xFE976C9E,0xFF5506A9,0xFA1A102C,0xFBD87A1B,0xF99EC442,0xF85CAE75, + 0xF300E948,0xF2C2837F,0xF0843D26,0xF1465711,0xF4094194,0xF5CB2BA3,0xF78D95FA,0xF64FFFCD, + 0xD9785D60,0xD8BA3757,0xDAFC890E,0xDB3EE339,0xDE71F5BC,0xDFB39F8B,0xDDF521D2,0xDC374BE5, + 0xD76B0CD8,0xD6A966EF,0xD4EFD8B6,0xD52DB281,0xD062A404,0xD1A0CE33,0xD3E6706A,0xD2241A5D, + 0xC55EFE10,0xC49C9427,0xC6DA2A7E,0xC7184049,0xC25756CC,0xC3953CFB,0xC1D382A2,0xC011E895, + 0xCB4DAFA8,0xCA8FC59F,0xC8C97BC6,0xC90B11F1,0xCC440774,0xCD866D43,0xCFC0D31A,0xCE02B92D, + 0x91AF9640,0x906DFC77,0x922B422E,0x93E92819,0x96A63E9C,0x976454AB,0x9522EAF2,0x94E080C5, + 0x9FBCC7F8,0x9E7EADCF,0x9C381396,0x9DFA79A1,0x98B56F24,0x99770513,0x9B31BB4A,0x9AF3D17D, + 0x8D893530,0x8C4B5F07,0x8E0DE15E,0x8FCF8B69,0x8A809DEC,0x8B42F7DB,0x89044982,0x88C623B5, + 0x839A6488,0x82580EBF,0x801EB0E6,0x81DCDAD1,0x8493CC54,0x8551A663,0x8717183A,0x86D5720D, + 0xA9E2D0A0,0xA820BA97,0xAA6604CE,0xABA46EF9,0xAEEB787C,0xAF29124B,0xAD6FAC12,0xACADC625, + 0xA7F18118,0xA633EB2F,0xA4755576,0xA5B73F41,0xA0F829C4,0xA13A43F3,0xA37CFDAA,0xA2BE979D, + 0xB5C473D0,0xB40619E7,0xB640A7BE,0xB782CD89,0xB2CDDB0C,0xB30FB13B,0xB1490F62,0xB08B6555, + 0xBBD72268,0xBA15485F,0xB853F606,0xB9919C31,0xBCDE8AB4,0xBD1CE083,0xBF5A5EDA,0xBE9834ED, + }, + + { + 0x00000000,0xB8BC6765,0xAA09C88B,0x12B5AFEE,0x8F629757,0x37DEF032,0x256B5FDC,0x9DD738B9, + 0xC5B428EF,0x7D084F8A,0x6FBDE064,0xD7018701,0x4AD6BFB8,0xF26AD8DD,0xE0DF7733,0x58631056, + 0x5019579F,0xE8A530FA,0xFA109F14,0x42ACF871,0xDF7BC0C8,0x67C7A7AD,0x75720843,0xCDCE6F26, + 0x95AD7F70,0x2D111815,0x3FA4B7FB,0x8718D09E,0x1ACFE827,0xA2738F42,0xB0C620AC,0x087A47C9, + 0xA032AF3E,0x188EC85B,0x0A3B67B5,0xB28700D0,0x2F503869,0x97EC5F0C,0x8559F0E2,0x3DE59787, + 0x658687D1,0xDD3AE0B4,0xCF8F4F5A,0x7733283F,0xEAE41086,0x525877E3,0x40EDD80D,0xF851BF68, + 0xF02BF8A1,0x48979FC4,0x5A22302A,0xE29E574F,0x7F496FF6,0xC7F50893,0xD540A77D,0x6DFCC018, + 0x359FD04E,0x8D23B72B,0x9F9618C5,0x272A7FA0,0xBAFD4719,0x0241207C,0x10F48F92,0xA848E8F7, + 0x9B14583D,0x23A83F58,0x311D90B6,0x89A1F7D3,0x1476CF6A,0xACCAA80F,0xBE7F07E1,0x06C36084, + 0x5EA070D2,0xE61C17B7,0xF4A9B859,0x4C15DF3C,0xD1C2E785,0x697E80E0,0x7BCB2F0E,0xC377486B, + 0xCB0D0FA2,0x73B168C7,0x6104C729,0xD9B8A04C,0x446F98F5,0xFCD3FF90,0xEE66507E,0x56DA371B, + 0x0EB9274D,0xB6054028,0xA4B0EFC6,0x1C0C88A3,0x81DBB01A,0x3967D77F,0x2BD27891,0x936E1FF4, + 0x3B26F703,0x839A9066,0x912F3F88,0x299358ED,0xB4446054,0x0CF80731,0x1E4DA8DF,0xA6F1CFBA, + 0xFE92DFEC,0x462EB889,0x549B1767,0xEC277002,0x71F048BB,0xC94C2FDE,0xDBF98030,0x6345E755, + 0x6B3FA09C,0xD383C7F9,0xC1366817,0x798A0F72,0xE45D37CB,0x5CE150AE,0x4E54FF40,0xF6E89825, + 0xAE8B8873,0x1637EF16,0x048240F8,0xBC3E279D,0x21E91F24,0x99557841,0x8BE0D7AF,0x335CB0CA, + 0xED59B63B,0x55E5D15E,0x47507EB0,0xFFEC19D5,0x623B216C,0xDA874609,0xC832E9E7,0x708E8E82, + 0x28ED9ED4,0x9051F9B1,0x82E4565F,0x3A58313A,0xA78F0983,0x1F336EE6,0x0D86C108,0xB53AA66D, + 0xBD40E1A4,0x05FC86C1,0x1749292F,0xAFF54E4A,0x322276F3,0x8A9E1196,0x982BBE78,0x2097D91D, + 0x78F4C94B,0xC048AE2E,0xD2FD01C0,0x6A4166A5,0xF7965E1C,0x4F2A3979,0x5D9F9697,0xE523F1F2, + 0x4D6B1905,0xF5D77E60,0xE762D18E,0x5FDEB6EB,0xC2098E52,0x7AB5E937,0x680046D9,0xD0BC21BC, + 0x88DF31EA,0x3063568F,0x22D6F961,0x9A6A9E04,0x07BDA6BD,0xBF01C1D8,0xADB46E36,0x15080953, + 0x1D724E9A,0xA5CE29FF,0xB77B8611,0x0FC7E174,0x9210D9CD,0x2AACBEA8,0x38191146,0x80A57623, + 0xD8C66675,0x607A0110,0x72CFAEFE,0xCA73C99B,0x57A4F122,0xEF189647,0xFDAD39A9,0x45115ECC, + 0x764DEE06,0xCEF18963,0xDC44268D,0x64F841E8,0xF92F7951,0x41931E34,0x5326B1DA,0xEB9AD6BF, + 0xB3F9C6E9,0x0B45A18C,0x19F00E62,0xA14C6907,0x3C9B51BE,0x842736DB,0x96929935,0x2E2EFE50, + 0x2654B999,0x9EE8DEFC,0x8C5D7112,0x34E11677,0xA9362ECE,0x118A49AB,0x033FE645,0xBB838120, + 0xE3E09176,0x5B5CF613,0x49E959FD,0xF1553E98,0x6C820621,0xD43E6144,0xC68BCEAA,0x7E37A9CF, + 0xD67F4138,0x6EC3265D,0x7C7689B3,0xC4CAEED6,0x591DD66F,0xE1A1B10A,0xF3141EE4,0x4BA87981, + 0x13CB69D7,0xAB770EB2,0xB9C2A15C,0x017EC639,0x9CA9FE80,0x241599E5,0x36A0360B,0x8E1C516E, + 0x866616A7,0x3EDA71C2,0x2C6FDE2C,0x94D3B949,0x090481F0,0xB1B8E695,0xA30D497B,0x1BB12E1E, + 0x43D23E48,0xFB6E592D,0xE9DBF6C3,0x516791A6,0xCCB0A91F,0x740CCE7A,0x66B96194,0xDE0506F1, + } +#endif // defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_4) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_16) +#if defined (CRC32_USE_LOOKUP_TABLE_SLICING_BY_8) || defined(CRC32_USE_LOOKUP_TABLE_SLICING_BY_16) + // beyond this point only relevant for Slicing-by-8 and Slicing-by-16 + ,{ + 0x00000000,0x3D6029B0,0x7AC05360,0x47A07AD0,0xF580A6C0,0xC8E08F70,0x8F40F5A0,0xB220DC10, + 0x30704BC1,0x0D106271,0x4AB018A1,0x77D03111,0xC5F0ED01,0xF890C4B1,0xBF30BE61,0x825097D1, + 0x60E09782,0x5D80BE32,0x1A20C4E2,0x2740ED52,0x95603142,0xA80018F2,0xEFA06222,0xD2C04B92, + 0x5090DC43,0x6DF0F5F3,0x2A508F23,0x1730A693,0xA5107A83,0x98705333,0xDFD029E3,0xE2B00053, + 0xC1C12F04,0xFCA106B4,0xBB017C64,0x866155D4,0x344189C4,0x0921A074,0x4E81DAA4,0x73E1F314, + 0xF1B164C5,0xCCD14D75,0x8B7137A5,0xB6111E15,0x0431C205,0x3951EBB5,0x7EF19165,0x4391B8D5, + 0xA121B886,0x9C419136,0xDBE1EBE6,0xE681C256,0x54A11E46,0x69C137F6,0x2E614D26,0x13016496, + 0x9151F347,0xAC31DAF7,0xEB91A027,0xD6F18997,0x64D15587,0x59B17C37,0x1E1106E7,0x23712F57, + 0x58F35849,0x659371F9,0x22330B29,0x1F532299,0xAD73FE89,0x9013D739,0xD7B3ADE9,0xEAD38459, + 0x68831388,0x55E33A38,0x124340E8,0x2F236958,0x9D03B548,0xA0639CF8,0xE7C3E628,0xDAA3CF98, + 0x3813CFCB,0x0573E67B,0x42D39CAB,0x7FB3B51B,0xCD93690B,0xF0F340BB,0xB7533A6B,0x8A3313DB, + 0x0863840A,0x3503ADBA,0x72A3D76A,0x4FC3FEDA,0xFDE322CA,0xC0830B7A,0x872371AA,0xBA43581A, + 0x9932774D,0xA4525EFD,0xE3F2242D,0xDE920D9D,0x6CB2D18D,0x51D2F83D,0x167282ED,0x2B12AB5D, + 0xA9423C8C,0x9422153C,0xD3826FEC,0xEEE2465C,0x5CC29A4C,0x61A2B3FC,0x2602C92C,0x1B62E09C, + 0xF9D2E0CF,0xC4B2C97F,0x8312B3AF,0xBE729A1F,0x0C52460F,0x31326FBF,0x7692156F,0x4BF23CDF, + 0xC9A2AB0E,0xF4C282BE,0xB362F86E,0x8E02D1DE,0x3C220DCE,0x0142247E,0x46E25EAE,0x7B82771E, + 0xB1E6B092,0x8C869922,0xCB26E3F2,0xF646CA42,0x44661652,0x79063FE2,0x3EA64532,0x03C66C82, + 0x8196FB53,0xBCF6D2E3,0xFB56A833,0xC6368183,0x74165D93,0x49767423,0x0ED60EF3,0x33B62743, + 0xD1062710,0xEC660EA0,0xABC67470,0x96A65DC0,0x248681D0,0x19E6A860,0x5E46D2B0,0x6326FB00, + 0xE1766CD1,0xDC164561,0x9BB63FB1,0xA6D61601,0x14F6CA11,0x2996E3A1,0x6E369971,0x5356B0C1, + 0x70279F96,0x4D47B626,0x0AE7CCF6,0x3787E546,0x85A73956,0xB8C710E6,0xFF676A36,0xC2074386, + 0x4057D457,0x7D37FDE7,0x3A978737,0x07F7AE87,0xB5D77297,0x88B75B27,0xCF1721F7,0xF2770847, + 0x10C70814,0x2DA721A4,0x6A075B74,0x576772C4,0xE547AED4,0xD8278764,0x9F87FDB4,0xA2E7D404, + 0x20B743D5,0x1DD76A65,0x5A7710B5,0x67173905,0xD537E515,0xE857CCA5,0xAFF7B675,0x92979FC5, + 0xE915E8DB,0xD475C16B,0x93D5BBBB,0xAEB5920B,0x1C954E1B,0x21F567AB,0x66551D7B,0x5B3534CB, + 0xD965A31A,0xE4058AAA,0xA3A5F07A,0x9EC5D9CA,0x2CE505DA,0x11852C6A,0x562556BA,0x6B457F0A, + 0x89F57F59,0xB49556E9,0xF3352C39,0xCE550589,0x7C75D999,0x4115F029,0x06B58AF9,0x3BD5A349, + 0xB9853498,0x84E51D28,0xC34567F8,0xFE254E48,0x4C059258,0x7165BBE8,0x36C5C138,0x0BA5E888, + 0x28D4C7DF,0x15B4EE6F,0x521494BF,0x6F74BD0F,0xDD54611F,0xE03448AF,0xA794327F,0x9AF41BCF, + 0x18A48C1E,0x25C4A5AE,0x6264DF7E,0x5F04F6CE,0xED242ADE,0xD044036E,0x97E479BE,0xAA84500E, + 0x4834505D,0x755479ED,0x32F4033D,0x0F942A8D,0xBDB4F69D,0x80D4DF2D,0xC774A5FD,0xFA148C4D, + 0x78441B9C,0x4524322C,0x028448FC,0x3FE4614C,0x8DC4BD5C,0xB0A494EC,0xF704EE3C,0xCA64C78C, + }, + + { + 0x00000000,0xCB5CD3A5,0x4DC8A10B,0x869472AE,0x9B914216,0x50CD91B3,0xD659E31D,0x1D0530B8, + 0xEC53826D,0x270F51C8,0xA19B2366,0x6AC7F0C3,0x77C2C07B,0xBC9E13DE,0x3A0A6170,0xF156B2D5, + 0x03D6029B,0xC88AD13E,0x4E1EA390,0x85427035,0x9847408D,0x531B9328,0xD58FE186,0x1ED33223, + 0xEF8580F6,0x24D95353,0xA24D21FD,0x6911F258,0x7414C2E0,0xBF481145,0x39DC63EB,0xF280B04E, + 0x07AC0536,0xCCF0D693,0x4A64A43D,0x81387798,0x9C3D4720,0x57619485,0xD1F5E62B,0x1AA9358E, + 0xEBFF875B,0x20A354FE,0xA6372650,0x6D6BF5F5,0x706EC54D,0xBB3216E8,0x3DA66446,0xF6FAB7E3, + 0x047A07AD,0xCF26D408,0x49B2A6A6,0x82EE7503,0x9FEB45BB,0x54B7961E,0xD223E4B0,0x197F3715, + 0xE82985C0,0x23755665,0xA5E124CB,0x6EBDF76E,0x73B8C7D6,0xB8E41473,0x3E7066DD,0xF52CB578, + 0x0F580A6C,0xC404D9C9,0x4290AB67,0x89CC78C2,0x94C9487A,0x5F959BDF,0xD901E971,0x125D3AD4, + 0xE30B8801,0x28575BA4,0xAEC3290A,0x659FFAAF,0x789ACA17,0xB3C619B2,0x35526B1C,0xFE0EB8B9, + 0x0C8E08F7,0xC7D2DB52,0x4146A9FC,0x8A1A7A59,0x971F4AE1,0x5C439944,0xDAD7EBEA,0x118B384F, + 0xE0DD8A9A,0x2B81593F,0xAD152B91,0x6649F834,0x7B4CC88C,0xB0101B29,0x36846987,0xFDD8BA22, + 0x08F40F5A,0xC3A8DCFF,0x453CAE51,0x8E607DF4,0x93654D4C,0x58399EE9,0xDEADEC47,0x15F13FE2, + 0xE4A78D37,0x2FFB5E92,0xA96F2C3C,0x6233FF99,0x7F36CF21,0xB46A1C84,0x32FE6E2A,0xF9A2BD8F, + 0x0B220DC1,0xC07EDE64,0x46EAACCA,0x8DB67F6F,0x90B34FD7,0x5BEF9C72,0xDD7BEEDC,0x16273D79, + 0xE7718FAC,0x2C2D5C09,0xAAB92EA7,0x61E5FD02,0x7CE0CDBA,0xB7BC1E1F,0x31286CB1,0xFA74BF14, + 0x1EB014D8,0xD5ECC77D,0x5378B5D3,0x98246676,0x852156CE,0x4E7D856B,0xC8E9F7C5,0x03B52460, + 0xF2E396B5,0x39BF4510,0xBF2B37BE,0x7477E41B,0x6972D4A3,0xA22E0706,0x24BA75A8,0xEFE6A60D, + 0x1D661643,0xD63AC5E6,0x50AEB748,0x9BF264ED,0x86F75455,0x4DAB87F0,0xCB3FF55E,0x006326FB, + 0xF135942E,0x3A69478B,0xBCFD3525,0x77A1E680,0x6AA4D638,0xA1F8059D,0x276C7733,0xEC30A496, + 0x191C11EE,0xD240C24B,0x54D4B0E5,0x9F886340,0x828D53F8,0x49D1805D,0xCF45F2F3,0x04192156, + 0xF54F9383,0x3E134026,0xB8873288,0x73DBE12D,0x6EDED195,0xA5820230,0x2316709E,0xE84AA33B, + 0x1ACA1375,0xD196C0D0,0x5702B27E,0x9C5E61DB,0x815B5163,0x4A0782C6,0xCC93F068,0x07CF23CD, + 0xF6999118,0x3DC542BD,0xBB513013,0x700DE3B6,0x6D08D30E,0xA65400AB,0x20C07205,0xEB9CA1A0, + 0x11E81EB4,0xDAB4CD11,0x5C20BFBF,0x977C6C1A,0x8A795CA2,0x41258F07,0xC7B1FDA9,0x0CED2E0C, + 0xFDBB9CD9,0x36E74F7C,0xB0733DD2,0x7B2FEE77,0x662ADECF,0xAD760D6A,0x2BE27FC4,0xE0BEAC61, + 0x123E1C2F,0xD962CF8A,0x5FF6BD24,0x94AA6E81,0x89AF5E39,0x42F38D9C,0xC467FF32,0x0F3B2C97, + 0xFE6D9E42,0x35314DE7,0xB3A53F49,0x78F9ECEC,0x65FCDC54,0xAEA00FF1,0x28347D5F,0xE368AEFA, + 0x16441B82,0xDD18C827,0x5B8CBA89,0x90D0692C,0x8DD55994,0x46898A31,0xC01DF89F,0x0B412B3A, + 0xFA1799EF,0x314B4A4A,0xB7DF38E4,0x7C83EB41,0x6186DBF9,0xAADA085C,0x2C4E7AF2,0xE712A957, + 0x15921919,0xDECECABC,0x585AB812,0x93066BB7,0x8E035B0F,0x455F88AA,0xC3CBFA04,0x089729A1, + 0xF9C19B74,0x329D48D1,0xB4093A7F,0x7F55E9DA,0x6250D962,0xA90C0AC7,0x2F987869,0xE4C4ABCC, + }, + + { + 0x00000000,0xA6770BB4,0x979F1129,0x31E81A9D,0xF44F2413,0x52382FA7,0x63D0353A,0xC5A73E8E, + 0x33EF4E67,0x959845D3,0xA4705F4E,0x020754FA,0xC7A06A74,0x61D761C0,0x503F7B5D,0xF64870E9, + 0x67DE9CCE,0xC1A9977A,0xF0418DE7,0x56368653,0x9391B8DD,0x35E6B369,0x040EA9F4,0xA279A240, + 0x5431D2A9,0xF246D91D,0xC3AEC380,0x65D9C834,0xA07EF6BA,0x0609FD0E,0x37E1E793,0x9196EC27, + 0xCFBD399C,0x69CA3228,0x582228B5,0xFE552301,0x3BF21D8F,0x9D85163B,0xAC6D0CA6,0x0A1A0712, + 0xFC5277FB,0x5A257C4F,0x6BCD66D2,0xCDBA6D66,0x081D53E8,0xAE6A585C,0x9F8242C1,0x39F54975, + 0xA863A552,0x0E14AEE6,0x3FFCB47B,0x998BBFCF,0x5C2C8141,0xFA5B8AF5,0xCBB39068,0x6DC49BDC, + 0x9B8CEB35,0x3DFBE081,0x0C13FA1C,0xAA64F1A8,0x6FC3CF26,0xC9B4C492,0xF85CDE0F,0x5E2BD5BB, + 0x440B7579,0xE27C7ECD,0xD3946450,0x75E36FE4,0xB044516A,0x16335ADE,0x27DB4043,0x81AC4BF7, + 0x77E43B1E,0xD19330AA,0xE07B2A37,0x460C2183,0x83AB1F0D,0x25DC14B9,0x14340E24,0xB2430590, + 0x23D5E9B7,0x85A2E203,0xB44AF89E,0x123DF32A,0xD79ACDA4,0x71EDC610,0x4005DC8D,0xE672D739, + 0x103AA7D0,0xB64DAC64,0x87A5B6F9,0x21D2BD4D,0xE47583C3,0x42028877,0x73EA92EA,0xD59D995E, + 0x8BB64CE5,0x2DC14751,0x1C295DCC,0xBA5E5678,0x7FF968F6,0xD98E6342,0xE86679DF,0x4E11726B, + 0xB8590282,0x1E2E0936,0x2FC613AB,0x89B1181F,0x4C162691,0xEA612D25,0xDB8937B8,0x7DFE3C0C, + 0xEC68D02B,0x4A1FDB9F,0x7BF7C102,0xDD80CAB6,0x1827F438,0xBE50FF8C,0x8FB8E511,0x29CFEEA5, + 0xDF879E4C,0x79F095F8,0x48188F65,0xEE6F84D1,0x2BC8BA5F,0x8DBFB1EB,0xBC57AB76,0x1A20A0C2, + 0x8816EAF2,0x2E61E146,0x1F89FBDB,0xB9FEF06F,0x7C59CEE1,0xDA2EC555,0xEBC6DFC8,0x4DB1D47C, + 0xBBF9A495,0x1D8EAF21,0x2C66B5BC,0x8A11BE08,0x4FB68086,0xE9C18B32,0xD82991AF,0x7E5E9A1B, + 0xEFC8763C,0x49BF7D88,0x78576715,0xDE206CA1,0x1B87522F,0xBDF0599B,0x8C184306,0x2A6F48B2, + 0xDC27385B,0x7A5033EF,0x4BB82972,0xEDCF22C6,0x28681C48,0x8E1F17FC,0xBFF70D61,0x198006D5, + 0x47ABD36E,0xE1DCD8DA,0xD034C247,0x7643C9F3,0xB3E4F77D,0x1593FCC9,0x247BE654,0x820CEDE0, + 0x74449D09,0xD23396BD,0xE3DB8C20,0x45AC8794,0x800BB91A,0x267CB2AE,0x1794A833,0xB1E3A387, + 0x20754FA0,0x86024414,0xB7EA5E89,0x119D553D,0xD43A6BB3,0x724D6007,0x43A57A9A,0xE5D2712E, + 0x139A01C7,0xB5ED0A73,0x840510EE,0x22721B5A,0xE7D525D4,0x41A22E60,0x704A34FD,0xD63D3F49, + 0xCC1D9F8B,0x6A6A943F,0x5B828EA2,0xFDF58516,0x3852BB98,0x9E25B02C,0xAFCDAAB1,0x09BAA105, + 0xFFF2D1EC,0x5985DA58,0x686DC0C5,0xCE1ACB71,0x0BBDF5FF,0xADCAFE4B,0x9C22E4D6,0x3A55EF62, + 0xABC30345,0x0DB408F1,0x3C5C126C,0x9A2B19D8,0x5F8C2756,0xF9FB2CE2,0xC813367F,0x6E643DCB, + 0x982C4D22,0x3E5B4696,0x0FB35C0B,0xA9C457BF,0x6C636931,0xCA146285,0xFBFC7818,0x5D8B73AC, + 0x03A0A617,0xA5D7ADA3,0x943FB73E,0x3248BC8A,0xF7EF8204,0x519889B0,0x6070932D,0xC6079899, + 0x304FE870,0x9638E3C4,0xA7D0F959,0x01A7F2ED,0xC400CC63,0x6277C7D7,0x539FDD4A,0xF5E8D6FE, + 0x647E3AD9,0xC209316D,0xF3E12BF0,0x55962044,0x90311ECA,0x3646157E,0x07AE0FE3,0xA1D90457, + 0x579174BE,0xF1E67F0A,0xC00E6597,0x66796E23,0xA3DE50AD,0x05A95B19,0x34414184,0x92364A30, + }, + + { + 0x00000000,0xCCAA009E,0x4225077D,0x8E8F07E3,0x844A0EFA,0x48E00E64,0xC66F0987,0x0AC50919, + 0xD3E51BB5,0x1F4F1B2B,0x91C01CC8,0x5D6A1C56,0x57AF154F,0x9B0515D1,0x158A1232,0xD92012AC, + 0x7CBB312B,0xB01131B5,0x3E9E3656,0xF23436C8,0xF8F13FD1,0x345B3F4F,0xBAD438AC,0x767E3832, + 0xAF5E2A9E,0x63F42A00,0xED7B2DE3,0x21D12D7D,0x2B142464,0xE7BE24FA,0x69312319,0xA59B2387, + 0xF9766256,0x35DC62C8,0xBB53652B,0x77F965B5,0x7D3C6CAC,0xB1966C32,0x3F196BD1,0xF3B36B4F, + 0x2A9379E3,0xE639797D,0x68B67E9E,0xA41C7E00,0xAED97719,0x62737787,0xECFC7064,0x205670FA, + 0x85CD537D,0x496753E3,0xC7E85400,0x0B42549E,0x01875D87,0xCD2D5D19,0x43A25AFA,0x8F085A64, + 0x562848C8,0x9A824856,0x140D4FB5,0xD8A74F2B,0xD2624632,0x1EC846AC,0x9047414F,0x5CED41D1, + 0x299DC2ED,0xE537C273,0x6BB8C590,0xA712C50E,0xADD7CC17,0x617DCC89,0xEFF2CB6A,0x2358CBF4, + 0xFA78D958,0x36D2D9C6,0xB85DDE25,0x74F7DEBB,0x7E32D7A2,0xB298D73C,0x3C17D0DF,0xF0BDD041, + 0x5526F3C6,0x998CF358,0x1703F4BB,0xDBA9F425,0xD16CFD3C,0x1DC6FDA2,0x9349FA41,0x5FE3FADF, + 0x86C3E873,0x4A69E8ED,0xC4E6EF0E,0x084CEF90,0x0289E689,0xCE23E617,0x40ACE1F4,0x8C06E16A, + 0xD0EBA0BB,0x1C41A025,0x92CEA7C6,0x5E64A758,0x54A1AE41,0x980BAEDF,0x1684A93C,0xDA2EA9A2, + 0x030EBB0E,0xCFA4BB90,0x412BBC73,0x8D81BCED,0x8744B5F4,0x4BEEB56A,0xC561B289,0x09CBB217, + 0xAC509190,0x60FA910E,0xEE7596ED,0x22DF9673,0x281A9F6A,0xE4B09FF4,0x6A3F9817,0xA6959889, + 0x7FB58A25,0xB31F8ABB,0x3D908D58,0xF13A8DC6,0xFBFF84DF,0x37558441,0xB9DA83A2,0x7570833C, + 0x533B85DA,0x9F918544,0x111E82A7,0xDDB48239,0xD7718B20,0x1BDB8BBE,0x95548C5D,0x59FE8CC3, + 0x80DE9E6F,0x4C749EF1,0xC2FB9912,0x0E51998C,0x04949095,0xC83E900B,0x46B197E8,0x8A1B9776, + 0x2F80B4F1,0xE32AB46F,0x6DA5B38C,0xA10FB312,0xABCABA0B,0x6760BA95,0xE9EFBD76,0x2545BDE8, + 0xFC65AF44,0x30CFAFDA,0xBE40A839,0x72EAA8A7,0x782FA1BE,0xB485A120,0x3A0AA6C3,0xF6A0A65D, + 0xAA4DE78C,0x66E7E712,0xE868E0F1,0x24C2E06F,0x2E07E976,0xE2ADE9E8,0x6C22EE0B,0xA088EE95, + 0x79A8FC39,0xB502FCA7,0x3B8DFB44,0xF727FBDA,0xFDE2F2C3,0x3148F25D,0xBFC7F5BE,0x736DF520, + 0xD6F6D6A7,0x1A5CD639,0x94D3D1DA,0x5879D144,0x52BCD85D,0x9E16D8C3,0x1099DF20,0xDC33DFBE, + 0x0513CD12,0xC9B9CD8C,0x4736CA6F,0x8B9CCAF1,0x8159C3E8,0x4DF3C376,0xC37CC495,0x0FD6C40B, + 0x7AA64737,0xB60C47A9,0x3883404A,0xF42940D4,0xFEEC49CD,0x32464953,0xBCC94EB0,0x70634E2E, + 0xA9435C82,0x65E95C1C,0xEB665BFF,0x27CC5B61,0x2D095278,0xE1A352E6,0x6F2C5505,0xA386559B, + 0x061D761C,0xCAB77682,0x44387161,0x889271FF,0x825778E6,0x4EFD7878,0xC0727F9B,0x0CD87F05, + 0xD5F86DA9,0x19526D37,0x97DD6AD4,0x5B776A4A,0x51B26353,0x9D1863CD,0x1397642E,0xDF3D64B0, + 0x83D02561,0x4F7A25FF,0xC1F5221C,0x0D5F2282,0x079A2B9B,0xCB302B05,0x45BF2CE6,0x89152C78, + 0x50353ED4,0x9C9F3E4A,0x121039A9,0xDEBA3937,0xD47F302E,0x18D530B0,0x965A3753,0x5AF037CD, + 0xFF6B144A,0x33C114D4,0xBD4E1337,0x71E413A9,0x7B211AB0,0xB78B1A2E,0x39041DCD,0xF5AE1D53, + 0x2C8E0FFF,0xE0240F61,0x6EAB0882,0xA201081C,0xA8C40105,0x646E019B,0xEAE10678,0x264B06E6, + } +#endif // CRC32_USE_LOOKUP_TABLE_SLICING_BY_8 || CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +#ifdef CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 + // beyond this point only relevant for Slicing-by-16 + ,{ + 0x00000000,0x177B1443,0x2EF62886,0x398D3CC5,0x5DEC510C,0x4A97454F,0x731A798A,0x64616DC9, + 0xBBD8A218,0xACA3B65B,0x952E8A9E,0x82559EDD,0xE634F314,0xF14FE757,0xC8C2DB92,0xDFB9CFD1, + 0xACC04271,0xBBBB5632,0x82366AF7,0x954D7EB4,0xF12C137D,0xE657073E,0xDFDA3BFB,0xC8A12FB8, + 0x1718E069,0x0063F42A,0x39EEC8EF,0x2E95DCAC,0x4AF4B165,0x5D8FA526,0x640299E3,0x73798DA0, + 0x82F182A3,0x958A96E0,0xAC07AA25,0xBB7CBE66,0xDF1DD3AF,0xC866C7EC,0xF1EBFB29,0xE690EF6A, + 0x392920BB,0x2E5234F8,0x17DF083D,0x00A41C7E,0x64C571B7,0x73BE65F4,0x4A335931,0x5D484D72, + 0x2E31C0D2,0x394AD491,0x00C7E854,0x17BCFC17,0x73DD91DE,0x64A6859D,0x5D2BB958,0x4A50AD1B, + 0x95E962CA,0x82927689,0xBB1F4A4C,0xAC645E0F,0xC80533C6,0xDF7E2785,0xE6F31B40,0xF1880F03, + 0xDE920307,0xC9E91744,0xF0642B81,0xE71F3FC2,0x837E520B,0x94054648,0xAD887A8D,0xBAF36ECE, + 0x654AA11F,0x7231B55C,0x4BBC8999,0x5CC79DDA,0x38A6F013,0x2FDDE450,0x1650D895,0x012BCCD6, + 0x72524176,0x65295535,0x5CA469F0,0x4BDF7DB3,0x2FBE107A,0x38C50439,0x014838FC,0x16332CBF, + 0xC98AE36E,0xDEF1F72D,0xE77CCBE8,0xF007DFAB,0x9466B262,0x831DA621,0xBA909AE4,0xADEB8EA7, + 0x5C6381A4,0x4B1895E7,0x7295A922,0x65EEBD61,0x018FD0A8,0x16F4C4EB,0x2F79F82E,0x3802EC6D, + 0xE7BB23BC,0xF0C037FF,0xC94D0B3A,0xDE361F79,0xBA5772B0,0xAD2C66F3,0x94A15A36,0x83DA4E75, + 0xF0A3C3D5,0xE7D8D796,0xDE55EB53,0xC92EFF10,0xAD4F92D9,0xBA34869A,0x83B9BA5F,0x94C2AE1C, + 0x4B7B61CD,0x5C00758E,0x658D494B,0x72F65D08,0x169730C1,0x01EC2482,0x38611847,0x2F1A0C04, + 0x6655004F,0x712E140C,0x48A328C9,0x5FD83C8A,0x3BB95143,0x2CC24500,0x154F79C5,0x02346D86, + 0xDD8DA257,0xCAF6B614,0xF37B8AD1,0xE4009E92,0x8061F35B,0x971AE718,0xAE97DBDD,0xB9ECCF9E, + 0xCA95423E,0xDDEE567D,0xE4636AB8,0xF3187EFB,0x97791332,0x80020771,0xB98F3BB4,0xAEF42FF7, + 0x714DE026,0x6636F465,0x5FBBC8A0,0x48C0DCE3,0x2CA1B12A,0x3BDAA569,0x025799AC,0x152C8DEF, + 0xE4A482EC,0xF3DF96AF,0xCA52AA6A,0xDD29BE29,0xB948D3E0,0xAE33C7A3,0x97BEFB66,0x80C5EF25, + 0x5F7C20F4,0x480734B7,0x718A0872,0x66F11C31,0x029071F8,0x15EB65BB,0x2C66597E,0x3B1D4D3D, + 0x4864C09D,0x5F1FD4DE,0x6692E81B,0x71E9FC58,0x15889191,0x02F385D2,0x3B7EB917,0x2C05AD54, + 0xF3BC6285,0xE4C776C6,0xDD4A4A03,0xCA315E40,0xAE503389,0xB92B27CA,0x80A61B0F,0x97DD0F4C, + 0xB8C70348,0xAFBC170B,0x96312BCE,0x814A3F8D,0xE52B5244,0xF2504607,0xCBDD7AC2,0xDCA66E81, + 0x031FA150,0x1464B513,0x2DE989D6,0x3A929D95,0x5EF3F05C,0x4988E41F,0x7005D8DA,0x677ECC99, + 0x14074139,0x037C557A,0x3AF169BF,0x2D8A7DFC,0x49EB1035,0x5E900476,0x671D38B3,0x70662CF0, + 0xAFDFE321,0xB8A4F762,0x8129CBA7,0x9652DFE4,0xF233B22D,0xE548A66E,0xDCC59AAB,0xCBBE8EE8, + 0x3A3681EB,0x2D4D95A8,0x14C0A96D,0x03BBBD2E,0x67DAD0E7,0x70A1C4A4,0x492CF861,0x5E57EC22, + 0x81EE23F3,0x969537B0,0xAF180B75,0xB8631F36,0xDC0272FF,0xCB7966BC,0xF2F45A79,0xE58F4E3A, + 0x96F6C39A,0x818DD7D9,0xB800EB1C,0xAF7BFF5F,0xCB1A9296,0xDC6186D5,0xE5ECBA10,0xF297AE53, + 0x2D2E6182,0x3A5575C1,0x03D84904,0x14A35D47,0x70C2308E,0x67B924CD,0x5E341808,0x494F0C4B, + }, + + { + 0x00000000,0xEFC26B3E,0x04F5D03D,0xEB37BB03,0x09EBA07A,0xE629CB44,0x0D1E7047,0xE2DC1B79, + 0x13D740F4,0xFC152BCA,0x172290C9,0xF8E0FBF7,0x1A3CE08E,0xF5FE8BB0,0x1EC930B3,0xF10B5B8D, + 0x27AE81E8,0xC86CEAD6,0x235B51D5,0xCC993AEB,0x2E452192,0xC1874AAC,0x2AB0F1AF,0xC5729A91, + 0x3479C11C,0xDBBBAA22,0x308C1121,0xDF4E7A1F,0x3D926166,0xD2500A58,0x3967B15B,0xD6A5DA65, + 0x4F5D03D0,0xA09F68EE,0x4BA8D3ED,0xA46AB8D3,0x46B6A3AA,0xA974C894,0x42437397,0xAD8118A9, + 0x5C8A4324,0xB348281A,0x587F9319,0xB7BDF827,0x5561E35E,0xBAA38860,0x51943363,0xBE56585D, + 0x68F38238,0x8731E906,0x6C065205,0x83C4393B,0x61182242,0x8EDA497C,0x65EDF27F,0x8A2F9941, + 0x7B24C2CC,0x94E6A9F2,0x7FD112F1,0x901379CF,0x72CF62B6,0x9D0D0988,0x763AB28B,0x99F8D9B5, + 0x9EBA07A0,0x71786C9E,0x9A4FD79D,0x758DBCA3,0x9751A7DA,0x7893CCE4,0x93A477E7,0x7C661CD9, + 0x8D6D4754,0x62AF2C6A,0x89989769,0x665AFC57,0x8486E72E,0x6B448C10,0x80733713,0x6FB15C2D, + 0xB9148648,0x56D6ED76,0xBDE15675,0x52233D4B,0xB0FF2632,0x5F3D4D0C,0xB40AF60F,0x5BC89D31, + 0xAAC3C6BC,0x4501AD82,0xAE361681,0x41F47DBF,0xA32866C6,0x4CEA0DF8,0xA7DDB6FB,0x481FDDC5, + 0xD1E70470,0x3E256F4E,0xD512D44D,0x3AD0BF73,0xD80CA40A,0x37CECF34,0xDCF97437,0x333B1F09, + 0xC2304484,0x2DF22FBA,0xC6C594B9,0x2907FF87,0xCBDBE4FE,0x24198FC0,0xCF2E34C3,0x20EC5FFD, + 0xF6498598,0x198BEEA6,0xF2BC55A5,0x1D7E3E9B,0xFFA225E2,0x10604EDC,0xFB57F5DF,0x14959EE1, + 0xE59EC56C,0x0A5CAE52,0xE16B1551,0x0EA97E6F,0xEC756516,0x03B70E28,0xE880B52B,0x0742DE15, + 0xE6050901,0x09C7623F,0xE2F0D93C,0x0D32B202,0xEFEEA97B,0x002CC245,0xEB1B7946,0x04D91278, + 0xF5D249F5,0x1A1022CB,0xF12799C8,0x1EE5F2F6,0xFC39E98F,0x13FB82B1,0xF8CC39B2,0x170E528C, + 0xC1AB88E9,0x2E69E3D7,0xC55E58D4,0x2A9C33EA,0xC8402893,0x278243AD,0xCCB5F8AE,0x23779390, + 0xD27CC81D,0x3DBEA323,0xD6891820,0x394B731E,0xDB976867,0x34550359,0xDF62B85A,0x30A0D364, + 0xA9580AD1,0x469A61EF,0xADADDAEC,0x426FB1D2,0xA0B3AAAB,0x4F71C195,0xA4467A96,0x4B8411A8, + 0xBA8F4A25,0x554D211B,0xBE7A9A18,0x51B8F126,0xB364EA5F,0x5CA68161,0xB7913A62,0x5853515C, + 0x8EF68B39,0x6134E007,0x8A035B04,0x65C1303A,0x871D2B43,0x68DF407D,0x83E8FB7E,0x6C2A9040, + 0x9D21CBCD,0x72E3A0F3,0x99D41BF0,0x761670CE,0x94CA6BB7,0x7B080089,0x903FBB8A,0x7FFDD0B4, + 0x78BF0EA1,0x977D659F,0x7C4ADE9C,0x9388B5A2,0x7154AEDB,0x9E96C5E5,0x75A17EE6,0x9A6315D8, + 0x6B684E55,0x84AA256B,0x6F9D9E68,0x805FF556,0x6283EE2F,0x8D418511,0x66763E12,0x89B4552C, + 0x5F118F49,0xB0D3E477,0x5BE45F74,0xB426344A,0x56FA2F33,0xB938440D,0x520FFF0E,0xBDCD9430, + 0x4CC6CFBD,0xA304A483,0x48331F80,0xA7F174BE,0x452D6FC7,0xAAEF04F9,0x41D8BFFA,0xAE1AD4C4, + 0x37E20D71,0xD820664F,0x3317DD4C,0xDCD5B672,0x3E09AD0B,0xD1CBC635,0x3AFC7D36,0xD53E1608, + 0x24354D85,0xCBF726BB,0x20C09DB8,0xCF02F686,0x2DDEEDFF,0xC21C86C1,0x292B3DC2,0xC6E956FC, + 0x104C8C99,0xFF8EE7A7,0x14B95CA4,0xFB7B379A,0x19A72CE3,0xF66547DD,0x1D52FCDE,0xF29097E0, + 0x039BCC6D,0xEC59A753,0x076E1C50,0xE8AC776E,0x0A706C17,0xE5B20729,0x0E85BC2A,0xE147D714, + }, + + { + 0x00000000,0xC18EDFC0,0x586CB9C1,0x99E26601,0xB0D97382,0x7157AC42,0xE8B5CA43,0x293B1583, + 0xBAC3E145,0x7B4D3E85,0xE2AF5884,0x23218744,0x0A1A92C7,0xCB944D07,0x52762B06,0x93F8F4C6, + 0xAEF6C4CB,0x6F781B0B,0xF69A7D0A,0x3714A2CA,0x1E2FB749,0xDFA16889,0x46430E88,0x87CDD148, + 0x1435258E,0xD5BBFA4E,0x4C599C4F,0x8DD7438F,0xA4EC560C,0x656289CC,0xFC80EFCD,0x3D0E300D, + 0x869C8FD7,0x47125017,0xDEF03616,0x1F7EE9D6,0x3645FC55,0xF7CB2395,0x6E294594,0xAFA79A54, + 0x3C5F6E92,0xFDD1B152,0x6433D753,0xA5BD0893,0x8C861D10,0x4D08C2D0,0xD4EAA4D1,0x15647B11, + 0x286A4B1C,0xE9E494DC,0x7006F2DD,0xB1882D1D,0x98B3389E,0x593DE75E,0xC0DF815F,0x01515E9F, + 0x92A9AA59,0x53277599,0xCAC51398,0x0B4BCC58,0x2270D9DB,0xE3FE061B,0x7A1C601A,0xBB92BFDA, + 0xD64819EF,0x17C6C62F,0x8E24A02E,0x4FAA7FEE,0x66916A6D,0xA71FB5AD,0x3EFDD3AC,0xFF730C6C, + 0x6C8BF8AA,0xAD05276A,0x34E7416B,0xF5699EAB,0xDC528B28,0x1DDC54E8,0x843E32E9,0x45B0ED29, + 0x78BEDD24,0xB93002E4,0x20D264E5,0xE15CBB25,0xC867AEA6,0x09E97166,0x900B1767,0x5185C8A7, + 0xC27D3C61,0x03F3E3A1,0x9A1185A0,0x5B9F5A60,0x72A44FE3,0xB32A9023,0x2AC8F622,0xEB4629E2, + 0x50D49638,0x915A49F8,0x08B82FF9,0xC936F039,0xE00DE5BA,0x21833A7A,0xB8615C7B,0x79EF83BB, + 0xEA17777D,0x2B99A8BD,0xB27BCEBC,0x73F5117C,0x5ACE04FF,0x9B40DB3F,0x02A2BD3E,0xC32C62FE, + 0xFE2252F3,0x3FAC8D33,0xA64EEB32,0x67C034F2,0x4EFB2171,0x8F75FEB1,0x169798B0,0xD7194770, + 0x44E1B3B6,0x856F6C76,0x1C8D0A77,0xDD03D5B7,0xF438C034,0x35B61FF4,0xAC5479F5,0x6DDAA635, + 0x77E1359F,0xB66FEA5F,0x2F8D8C5E,0xEE03539E,0xC738461D,0x06B699DD,0x9F54FFDC,0x5EDA201C, + 0xCD22D4DA,0x0CAC0B1A,0x954E6D1B,0x54C0B2DB,0x7DFBA758,0xBC757898,0x25971E99,0xE419C159, + 0xD917F154,0x18992E94,0x817B4895,0x40F59755,0x69CE82D6,0xA8405D16,0x31A23B17,0xF02CE4D7, + 0x63D41011,0xA25ACFD1,0x3BB8A9D0,0xFA367610,0xD30D6393,0x1283BC53,0x8B61DA52,0x4AEF0592, + 0xF17DBA48,0x30F36588,0xA9110389,0x689FDC49,0x41A4C9CA,0x802A160A,0x19C8700B,0xD846AFCB, + 0x4BBE5B0D,0x8A3084CD,0x13D2E2CC,0xD25C3D0C,0xFB67288F,0x3AE9F74F,0xA30B914E,0x62854E8E, + 0x5F8B7E83,0x9E05A143,0x07E7C742,0xC6691882,0xEF520D01,0x2EDCD2C1,0xB73EB4C0,0x76B06B00, + 0xE5489FC6,0x24C64006,0xBD242607,0x7CAAF9C7,0x5591EC44,0x941F3384,0x0DFD5585,0xCC738A45, + 0xA1A92C70,0x6027F3B0,0xF9C595B1,0x384B4A71,0x11705FF2,0xD0FE8032,0x491CE633,0x889239F3, + 0x1B6ACD35,0xDAE412F5,0x430674F4,0x8288AB34,0xABB3BEB7,0x6A3D6177,0xF3DF0776,0x3251D8B6, + 0x0F5FE8BB,0xCED1377B,0x5733517A,0x96BD8EBA,0xBF869B39,0x7E0844F9,0xE7EA22F8,0x2664FD38, + 0xB59C09FE,0x7412D63E,0xEDF0B03F,0x2C7E6FFF,0x05457A7C,0xC4CBA5BC,0x5D29C3BD,0x9CA71C7D, + 0x2735A3A7,0xE6BB7C67,0x7F591A66,0xBED7C5A6,0x97ECD025,0x56620FE5,0xCF8069E4,0x0E0EB624, + 0x9DF642E2,0x5C789D22,0xC59AFB23,0x041424E3,0x2D2F3160,0xECA1EEA0,0x754388A1,0xB4CD5761, + 0x89C3676C,0x484DB8AC,0xD1AFDEAD,0x1021016D,0x391A14EE,0xF894CB2E,0x6176AD2F,0xA0F872EF, + 0x33008629,0xF28E59E9,0x6B6C3FE8,0xAAE2E028,0x83D9F5AB,0x42572A6B,0xDBB54C6A,0x1A3B93AA, + }, + + { + 0x00000000,0x9BA54C6F,0xEC3B9E9F,0x779ED2F0,0x03063B7F,0x98A37710,0xEF3DA5E0,0x7498E98F, + 0x060C76FE,0x9DA93A91,0xEA37E861,0x7192A40E,0x050A4D81,0x9EAF01EE,0xE931D31E,0x72949F71, + 0x0C18EDFC,0x97BDA193,0xE0237363,0x7B863F0C,0x0F1ED683,0x94BB9AEC,0xE325481C,0x78800473, + 0x0A149B02,0x91B1D76D,0xE62F059D,0x7D8A49F2,0x0912A07D,0x92B7EC12,0xE5293EE2,0x7E8C728D, + 0x1831DBF8,0x83949797,0xF40A4567,0x6FAF0908,0x1B37E087,0x8092ACE8,0xF70C7E18,0x6CA93277, + 0x1E3DAD06,0x8598E169,0xF2063399,0x69A37FF6,0x1D3B9679,0x869EDA16,0xF10008E6,0x6AA54489, + 0x14293604,0x8F8C7A6B,0xF812A89B,0x63B7E4F4,0x172F0D7B,0x8C8A4114,0xFB1493E4,0x60B1DF8B, + 0x122540FA,0x89800C95,0xFE1EDE65,0x65BB920A,0x11237B85,0x8A8637EA,0xFD18E51A,0x66BDA975, + 0x3063B7F0,0xABC6FB9F,0xDC58296F,0x47FD6500,0x33658C8F,0xA8C0C0E0,0xDF5E1210,0x44FB5E7F, + 0x366FC10E,0xADCA8D61,0xDA545F91,0x41F113FE,0x3569FA71,0xAECCB61E,0xD95264EE,0x42F72881, + 0x3C7B5A0C,0xA7DE1663,0xD040C493,0x4BE588FC,0x3F7D6173,0xA4D82D1C,0xD346FFEC,0x48E3B383, + 0x3A772CF2,0xA1D2609D,0xD64CB26D,0x4DE9FE02,0x3971178D,0xA2D45BE2,0xD54A8912,0x4EEFC57D, + 0x28526C08,0xB3F72067,0xC469F297,0x5FCCBEF8,0x2B545777,0xB0F11B18,0xC76FC9E8,0x5CCA8587, + 0x2E5E1AF6,0xB5FB5699,0xC2658469,0x59C0C806,0x2D582189,0xB6FD6DE6,0xC163BF16,0x5AC6F379, + 0x244A81F4,0xBFEFCD9B,0xC8711F6B,0x53D45304,0x274CBA8B,0xBCE9F6E4,0xCB772414,0x50D2687B, + 0x2246F70A,0xB9E3BB65,0xCE7D6995,0x55D825FA,0x2140CC75,0xBAE5801A,0xCD7B52EA,0x56DE1E85, + 0x60C76FE0,0xFB62238F,0x8CFCF17F,0x1759BD10,0x63C1549F,0xF86418F0,0x8FFACA00,0x145F866F, + 0x66CB191E,0xFD6E5571,0x8AF08781,0x1155CBEE,0x65CD2261,0xFE686E0E,0x89F6BCFE,0x1253F091, + 0x6CDF821C,0xF77ACE73,0x80E41C83,0x1B4150EC,0x6FD9B963,0xF47CF50C,0x83E227FC,0x18476B93, + 0x6AD3F4E2,0xF176B88D,0x86E86A7D,0x1D4D2612,0x69D5CF9D,0xF27083F2,0x85EE5102,0x1E4B1D6D, + 0x78F6B418,0xE353F877,0x94CD2A87,0x0F6866E8,0x7BF08F67,0xE055C308,0x97CB11F8,0x0C6E5D97, + 0x7EFAC2E6,0xE55F8E89,0x92C15C79,0x09641016,0x7DFCF999,0xE659B5F6,0x91C76706,0x0A622B69, + 0x74EE59E4,0xEF4B158B,0x98D5C77B,0x03708B14,0x77E8629B,0xEC4D2EF4,0x9BD3FC04,0x0076B06B, + 0x72E22F1A,0xE9476375,0x9ED9B185,0x057CFDEA,0x71E41465,0xEA41580A,0x9DDF8AFA,0x067AC695, + 0x50A4D810,0xCB01947F,0xBC9F468F,0x273A0AE0,0x53A2E36F,0xC807AF00,0xBF997DF0,0x243C319F, + 0x56A8AEEE,0xCD0DE281,0xBA933071,0x21367C1E,0x55AE9591,0xCE0BD9FE,0xB9950B0E,0x22304761, + 0x5CBC35EC,0xC7197983,0xB087AB73,0x2B22E71C,0x5FBA0E93,0xC41F42FC,0xB381900C,0x2824DC63, + 0x5AB04312,0xC1150F7D,0xB68BDD8D,0x2D2E91E2,0x59B6786D,0xC2133402,0xB58DE6F2,0x2E28AA9D, + 0x489503E8,0xD3304F87,0xA4AE9D77,0x3F0BD118,0x4B933897,0xD03674F8,0xA7A8A608,0x3C0DEA67, + 0x4E997516,0xD53C3979,0xA2A2EB89,0x3907A7E6,0x4D9F4E69,0xD63A0206,0xA1A4D0F6,0x3A019C99, + 0x448DEE14,0xDF28A27B,0xA8B6708B,0x33133CE4,0x478BD56B,0xDC2E9904,0xABB04BF4,0x3015079B, + 0x428198EA,0xD924D485,0xAEBA0675,0x351F4A1A,0x4187A395,0xDA22EFFA,0xADBC3D0A,0x36197165, + }, + + { + 0x00000000,0xDD96D985,0x605CB54B,0xBDCA6CCE,0xC0B96A96,0x1D2FB313,0xA0E5DFDD,0x7D730658, + 0x5A03D36D,0x87950AE8,0x3A5F6626,0xE7C9BFA3,0x9ABAB9FB,0x472C607E,0xFAE60CB0,0x2770D535, + 0xB407A6DA,0x69917F5F,0xD45B1391,0x09CDCA14,0x74BECC4C,0xA92815C9,0x14E27907,0xC974A082, + 0xEE0475B7,0x3392AC32,0x8E58C0FC,0x53CE1979,0x2EBD1F21,0xF32BC6A4,0x4EE1AA6A,0x937773EF, + 0xB37E4BF5,0x6EE89270,0xD322FEBE,0x0EB4273B,0x73C72163,0xAE51F8E6,0x139B9428,0xCE0D4DAD, + 0xE97D9898,0x34EB411D,0x89212DD3,0x54B7F456,0x29C4F20E,0xF4522B8B,0x49984745,0x940E9EC0, + 0x0779ED2F,0xDAEF34AA,0x67255864,0xBAB381E1,0xC7C087B9,0x1A565E3C,0xA79C32F2,0x7A0AEB77, + 0x5D7A3E42,0x80ECE7C7,0x3D268B09,0xE0B0528C,0x9DC354D4,0x40558D51,0xFD9FE19F,0x2009381A, + 0xBD8D91AB,0x601B482E,0xDDD124E0,0x0047FD65,0x7D34FB3D,0xA0A222B8,0x1D684E76,0xC0FE97F3, + 0xE78E42C6,0x3A189B43,0x87D2F78D,0x5A442E08,0x27372850,0xFAA1F1D5,0x476B9D1B,0x9AFD449E, + 0x098A3771,0xD41CEEF4,0x69D6823A,0xB4405BBF,0xC9335DE7,0x14A58462,0xA96FE8AC,0x74F93129, + 0x5389E41C,0x8E1F3D99,0x33D55157,0xEE4388D2,0x93308E8A,0x4EA6570F,0xF36C3BC1,0x2EFAE244, + 0x0EF3DA5E,0xD36503DB,0x6EAF6F15,0xB339B690,0xCE4AB0C8,0x13DC694D,0xAE160583,0x7380DC06, + 0x54F00933,0x8966D0B6,0x34ACBC78,0xE93A65FD,0x944963A5,0x49DFBA20,0xF415D6EE,0x29830F6B, + 0xBAF47C84,0x6762A501,0xDAA8C9CF,0x073E104A,0x7A4D1612,0xA7DBCF97,0x1A11A359,0xC7877ADC, + 0xE0F7AFE9,0x3D61766C,0x80AB1AA2,0x5D3DC327,0x204EC57F,0xFDD81CFA,0x40127034,0x9D84A9B1, + 0xA06A2517,0x7DFCFC92,0xC036905C,0x1DA049D9,0x60D34F81,0xBD459604,0x008FFACA,0xDD19234F, + 0xFA69F67A,0x27FF2FFF,0x9A354331,0x47A39AB4,0x3AD09CEC,0xE7464569,0x5A8C29A7,0x871AF022, + 0x146D83CD,0xC9FB5A48,0x74313686,0xA9A7EF03,0xD4D4E95B,0x094230DE,0xB4885C10,0x691E8595, + 0x4E6E50A0,0x93F88925,0x2E32E5EB,0xF3A43C6E,0x8ED73A36,0x5341E3B3,0xEE8B8F7D,0x331D56F8, + 0x13146EE2,0xCE82B767,0x7348DBA9,0xAEDE022C,0xD3AD0474,0x0E3BDDF1,0xB3F1B13F,0x6E6768BA, + 0x4917BD8F,0x9481640A,0x294B08C4,0xF4DDD141,0x89AED719,0x54380E9C,0xE9F26252,0x3464BBD7, + 0xA713C838,0x7A8511BD,0xC74F7D73,0x1AD9A4F6,0x67AAA2AE,0xBA3C7B2B,0x07F617E5,0xDA60CE60, + 0xFD101B55,0x2086C2D0,0x9D4CAE1E,0x40DA779B,0x3DA971C3,0xE03FA846,0x5DF5C488,0x80631D0D, + 0x1DE7B4BC,0xC0716D39,0x7DBB01F7,0xA02DD872,0xDD5EDE2A,0x00C807AF,0xBD026B61,0x6094B2E4, + 0x47E467D1,0x9A72BE54,0x27B8D29A,0xFA2E0B1F,0x875D0D47,0x5ACBD4C2,0xE701B80C,0x3A976189, + 0xA9E01266,0x7476CBE3,0xC9BCA72D,0x142A7EA8,0x695978F0,0xB4CFA175,0x0905CDBB,0xD493143E, + 0xF3E3C10B,0x2E75188E,0x93BF7440,0x4E29ADC5,0x335AAB9D,0xEECC7218,0x53061ED6,0x8E90C753, + 0xAE99FF49,0x730F26CC,0xCEC54A02,0x13539387,0x6E2095DF,0xB3B64C5A,0x0E7C2094,0xD3EAF911, + 0xF49A2C24,0x290CF5A1,0x94C6996F,0x495040EA,0x342346B2,0xE9B59F37,0x547FF3F9,0x89E92A7C, + 0x1A9E5993,0xC7088016,0x7AC2ECD8,0xA754355D,0xDA273305,0x07B1EA80,0xBA7B864E,0x67ED5FCB, + 0x409D8AFE,0x9D0B537B,0x20C13FB5,0xFD57E630,0x8024E068,0x5DB239ED,0xE0785523,0x3DEE8CA6, + }, + + { + 0x00000000,0x9D0FE176,0xE16EC4AD,0x7C6125DB,0x19AC8F1B,0x84A36E6D,0xF8C24BB6,0x65CDAAC0, + 0x33591E36,0xAE56FF40,0xD237DA9B,0x4F383BED,0x2AF5912D,0xB7FA705B,0xCB9B5580,0x5694B4F6, + 0x66B23C6C,0xFBBDDD1A,0x87DCF8C1,0x1AD319B7,0x7F1EB377,0xE2115201,0x9E7077DA,0x037F96AC, + 0x55EB225A,0xC8E4C32C,0xB485E6F7,0x298A0781,0x4C47AD41,0xD1484C37,0xAD2969EC,0x3026889A, + 0xCD6478D8,0x506B99AE,0x2C0ABC75,0xB1055D03,0xD4C8F7C3,0x49C716B5,0x35A6336E,0xA8A9D218, + 0xFE3D66EE,0x63328798,0x1F53A243,0x825C4335,0xE791E9F5,0x7A9E0883,0x06FF2D58,0x9BF0CC2E, + 0xABD644B4,0x36D9A5C2,0x4AB88019,0xD7B7616F,0xB27ACBAF,0x2F752AD9,0x53140F02,0xCE1BEE74, + 0x988F5A82,0x0580BBF4,0x79E19E2F,0xE4EE7F59,0x8123D599,0x1C2C34EF,0x604D1134,0xFD42F042, + 0x41B9F7F1,0xDCB61687,0xA0D7335C,0x3DD8D22A,0x581578EA,0xC51A999C,0xB97BBC47,0x24745D31, + 0x72E0E9C7,0xEFEF08B1,0x938E2D6A,0x0E81CC1C,0x6B4C66DC,0xF64387AA,0x8A22A271,0x172D4307, + 0x270BCB9D,0xBA042AEB,0xC6650F30,0x5B6AEE46,0x3EA74486,0xA3A8A5F0,0xDFC9802B,0x42C6615D, + 0x1452D5AB,0x895D34DD,0xF53C1106,0x6833F070,0x0DFE5AB0,0x90F1BBC6,0xEC909E1D,0x719F7F6B, + 0x8CDD8F29,0x11D26E5F,0x6DB34B84,0xF0BCAAF2,0x95710032,0x087EE144,0x741FC49F,0xE91025E9, + 0xBF84911F,0x228B7069,0x5EEA55B2,0xC3E5B4C4,0xA6281E04,0x3B27FF72,0x4746DAA9,0xDA493BDF, + 0xEA6FB345,0x77605233,0x0B0177E8,0x960E969E,0xF3C33C5E,0x6ECCDD28,0x12ADF8F3,0x8FA21985, + 0xD936AD73,0x44394C05,0x385869DE,0xA55788A8,0xC09A2268,0x5D95C31E,0x21F4E6C5,0xBCFB07B3, + 0x8373EFE2,0x1E7C0E94,0x621D2B4F,0xFF12CA39,0x9ADF60F9,0x07D0818F,0x7BB1A454,0xE6BE4522, + 0xB02AF1D4,0x2D2510A2,0x51443579,0xCC4BD40F,0xA9867ECF,0x34899FB9,0x48E8BA62,0xD5E75B14, + 0xE5C1D38E,0x78CE32F8,0x04AF1723,0x99A0F655,0xFC6D5C95,0x6162BDE3,0x1D039838,0x800C794E, + 0xD698CDB8,0x4B972CCE,0x37F60915,0xAAF9E863,0xCF3442A3,0x523BA3D5,0x2E5A860E,0xB3556778, + 0x4E17973A,0xD318764C,0xAF795397,0x3276B2E1,0x57BB1821,0xCAB4F957,0xB6D5DC8C,0x2BDA3DFA, + 0x7D4E890C,0xE041687A,0x9C204DA1,0x012FACD7,0x64E20617,0xF9EDE761,0x858CC2BA,0x188323CC, + 0x28A5AB56,0xB5AA4A20,0xC9CB6FFB,0x54C48E8D,0x3109244D,0xAC06C53B,0xD067E0E0,0x4D680196, + 0x1BFCB560,0x86F35416,0xFA9271CD,0x679D90BB,0x02503A7B,0x9F5FDB0D,0xE33EFED6,0x7E311FA0, + 0xC2CA1813,0x5FC5F965,0x23A4DCBE,0xBEAB3DC8,0xDB669708,0x4669767E,0x3A0853A5,0xA707B2D3, + 0xF1930625,0x6C9CE753,0x10FDC288,0x8DF223FE,0xE83F893E,0x75306848,0x09514D93,0x945EACE5, + 0xA478247F,0x3977C509,0x4516E0D2,0xD81901A4,0xBDD4AB64,0x20DB4A12,0x5CBA6FC9,0xC1B58EBF, + 0x97213A49,0x0A2EDB3F,0x764FFEE4,0xEB401F92,0x8E8DB552,0x13825424,0x6FE371FF,0xF2EC9089, + 0x0FAE60CB,0x92A181BD,0xEEC0A466,0x73CF4510,0x1602EFD0,0x8B0D0EA6,0xF76C2B7D,0x6A63CA0B, + 0x3CF77EFD,0xA1F89F8B,0xDD99BA50,0x40965B26,0x255BF1E6,0xB8541090,0xC435354B,0x593AD43D, + 0x691C5CA7,0xF413BDD1,0x8872980A,0x157D797C,0x70B0D3BC,0xEDBF32CA,0x91DE1711,0x0CD1F667, + 0x5A454291,0xC74AA3E7,0xBB2B863C,0x2624674A,0x43E9CD8A,0xDEE62CFC,0xA2870927,0x3F88E851, + }, + + { + 0x00000000,0xB9FBDBE8,0xA886B191,0x117D6A79,0x8A7C6563,0x3387BE8B,0x22FAD4F2,0x9B010F1A, + 0xCF89CC87,0x7672176F,0x670F7D16,0xDEF4A6FE,0x45F5A9E4,0xFC0E720C,0xED731875,0x5488C39D, + 0x44629F4F,0xFD9944A7,0xECE42EDE,0x551FF536,0xCE1EFA2C,0x77E521C4,0x66984BBD,0xDF639055, + 0x8BEB53C8,0x32108820,0x236DE259,0x9A9639B1,0x019736AB,0xB86CED43,0xA911873A,0x10EA5CD2, + 0x88C53E9E,0x313EE576,0x20438F0F,0x99B854E7,0x02B95BFD,0xBB428015,0xAA3FEA6C,0x13C43184, + 0x474CF219,0xFEB729F1,0xEFCA4388,0x56319860,0xCD30977A,0x74CB4C92,0x65B626EB,0xDC4DFD03, + 0xCCA7A1D1,0x755C7A39,0x64211040,0xDDDACBA8,0x46DBC4B2,0xFF201F5A,0xEE5D7523,0x57A6AECB, + 0x032E6D56,0xBAD5B6BE,0xABA8DCC7,0x1253072F,0x89520835,0x30A9D3DD,0x21D4B9A4,0x982F624C, + 0xCAFB7B7D,0x7300A095,0x627DCAEC,0xDB861104,0x40871E1E,0xF97CC5F6,0xE801AF8F,0x51FA7467, + 0x0572B7FA,0xBC896C12,0xADF4066B,0x140FDD83,0x8F0ED299,0x36F50971,0x27886308,0x9E73B8E0, + 0x8E99E432,0x37623FDA,0x261F55A3,0x9FE48E4B,0x04E58151,0xBD1E5AB9,0xAC6330C0,0x1598EB28, + 0x411028B5,0xF8EBF35D,0xE9969924,0x506D42CC,0xCB6C4DD6,0x7297963E,0x63EAFC47,0xDA1127AF, + 0x423E45E3,0xFBC59E0B,0xEAB8F472,0x53432F9A,0xC8422080,0x71B9FB68,0x60C49111,0xD93F4AF9, + 0x8DB78964,0x344C528C,0x253138F5,0x9CCAE31D,0x07CBEC07,0xBE3037EF,0xAF4D5D96,0x16B6867E, + 0x065CDAAC,0xBFA70144,0xAEDA6B3D,0x1721B0D5,0x8C20BFCF,0x35DB6427,0x24A60E5E,0x9D5DD5B6, + 0xC9D5162B,0x702ECDC3,0x6153A7BA,0xD8A87C52,0x43A97348,0xFA52A8A0,0xEB2FC2D9,0x52D41931, + 0x4E87F0BB,0xF77C2B53,0xE601412A,0x5FFA9AC2,0xC4FB95D8,0x7D004E30,0x6C7D2449,0xD586FFA1, + 0x810E3C3C,0x38F5E7D4,0x29888DAD,0x90735645,0x0B72595F,0xB28982B7,0xA3F4E8CE,0x1A0F3326, + 0x0AE56FF4,0xB31EB41C,0xA263DE65,0x1B98058D,0x80990A97,0x3962D17F,0x281FBB06,0x91E460EE, + 0xC56CA373,0x7C97789B,0x6DEA12E2,0xD411C90A,0x4F10C610,0xF6EB1DF8,0xE7967781,0x5E6DAC69, + 0xC642CE25,0x7FB915CD,0x6EC47FB4,0xD73FA45C,0x4C3EAB46,0xF5C570AE,0xE4B81AD7,0x5D43C13F, + 0x09CB02A2,0xB030D94A,0xA14DB333,0x18B668DB,0x83B767C1,0x3A4CBC29,0x2B31D650,0x92CA0DB8, + 0x8220516A,0x3BDB8A82,0x2AA6E0FB,0x935D3B13,0x085C3409,0xB1A7EFE1,0xA0DA8598,0x19215E70, + 0x4DA99DED,0xF4524605,0xE52F2C7C,0x5CD4F794,0xC7D5F88E,0x7E2E2366,0x6F53491F,0xD6A892F7, + 0x847C8BC6,0x3D87502E,0x2CFA3A57,0x9501E1BF,0x0E00EEA5,0xB7FB354D,0xA6865F34,0x1F7D84DC, + 0x4BF54741,0xF20E9CA9,0xE373F6D0,0x5A882D38,0xC1892222,0x7872F9CA,0x690F93B3,0xD0F4485B, + 0xC01E1489,0x79E5CF61,0x6898A518,0xD1637EF0,0x4A6271EA,0xF399AA02,0xE2E4C07B,0x5B1F1B93, + 0x0F97D80E,0xB66C03E6,0xA711699F,0x1EEAB277,0x85EBBD6D,0x3C106685,0x2D6D0CFC,0x9496D714, + 0x0CB9B558,0xB5426EB0,0xA43F04C9,0x1DC4DF21,0x86C5D03B,0x3F3E0BD3,0x2E4361AA,0x97B8BA42, + 0xC33079DF,0x7ACBA237,0x6BB6C84E,0xD24D13A6,0x494C1CBC,0xF0B7C754,0xE1CAAD2D,0x583176C5, + 0x48DB2A17,0xF120F1FF,0xE05D9B86,0x59A6406E,0xC2A74F74,0x7B5C949C,0x6A21FEE5,0xD3DA250D, + 0x8752E690,0x3EA93D78,0x2FD45701,0x962F8CE9,0x0D2E83F3,0xB4D5581B,0xA5A83262,0x1C53E98A, + }, + + { + 0x00000000,0xAE689191,0x87A02563,0x29C8B4F2,0xD4314C87,0x7A59DD16,0x539169E4,0xFDF9F875, + 0x73139F4F,0xDD7B0EDE,0xF4B3BA2C,0x5ADB2BBD,0xA722D3C8,0x094A4259,0x2082F6AB,0x8EEA673A, + 0xE6273E9E,0x484FAF0F,0x61871BFD,0xCFEF8A6C,0x32167219,0x9C7EE388,0xB5B6577A,0x1BDEC6EB, + 0x9534A1D1,0x3B5C3040,0x129484B2,0xBCFC1523,0x4105ED56,0xEF6D7CC7,0xC6A5C835,0x68CD59A4, + 0x173F7B7D,0xB957EAEC,0x909F5E1E,0x3EF7CF8F,0xC30E37FA,0x6D66A66B,0x44AE1299,0xEAC68308, + 0x642CE432,0xCA4475A3,0xE38CC151,0x4DE450C0,0xB01DA8B5,0x1E753924,0x37BD8DD6,0x99D51C47, + 0xF11845E3,0x5F70D472,0x76B86080,0xD8D0F111,0x25290964,0x8B4198F5,0xA2892C07,0x0CE1BD96, + 0x820BDAAC,0x2C634B3D,0x05ABFFCF,0xABC36E5E,0x563A962B,0xF85207BA,0xD19AB348,0x7FF222D9, + 0x2E7EF6FA,0x8016676B,0xA9DED399,0x07B64208,0xFA4FBA7D,0x54272BEC,0x7DEF9F1E,0xD3870E8F, + 0x5D6D69B5,0xF305F824,0xDACD4CD6,0x74A5DD47,0x895C2532,0x2734B4A3,0x0EFC0051,0xA09491C0, + 0xC859C864,0x663159F5,0x4FF9ED07,0xE1917C96,0x1C6884E3,0xB2001572,0x9BC8A180,0x35A03011, + 0xBB4A572B,0x1522C6BA,0x3CEA7248,0x9282E3D9,0x6F7B1BAC,0xC1138A3D,0xE8DB3ECF,0x46B3AF5E, + 0x39418D87,0x97291C16,0xBEE1A8E4,0x10893975,0xED70C100,0x43185091,0x6AD0E463,0xC4B875F2, + 0x4A5212C8,0xE43A8359,0xCDF237AB,0x639AA63A,0x9E635E4F,0x300BCFDE,0x19C37B2C,0xB7ABEABD, + 0xDF66B319,0x710E2288,0x58C6967A,0xF6AE07EB,0x0B57FF9E,0xA53F6E0F,0x8CF7DAFD,0x229F4B6C, + 0xAC752C56,0x021DBDC7,0x2BD50935,0x85BD98A4,0x784460D1,0xD62CF140,0xFFE445B2,0x518CD423, + 0x5CFDEDF4,0xF2957C65,0xDB5DC897,0x75355906,0x88CCA173,0x26A430E2,0x0F6C8410,0xA1041581, + 0x2FEE72BB,0x8186E32A,0xA84E57D8,0x0626C649,0xFBDF3E3C,0x55B7AFAD,0x7C7F1B5F,0xD2178ACE, + 0xBADAD36A,0x14B242FB,0x3D7AF609,0x93126798,0x6EEB9FED,0xC0830E7C,0xE94BBA8E,0x47232B1F, + 0xC9C94C25,0x67A1DDB4,0x4E696946,0xE001F8D7,0x1DF800A2,0xB3909133,0x9A5825C1,0x3430B450, + 0x4BC29689,0xE5AA0718,0xCC62B3EA,0x620A227B,0x9FF3DA0E,0x319B4B9F,0x1853FF6D,0xB63B6EFC, + 0x38D109C6,0x96B99857,0xBF712CA5,0x1119BD34,0xECE04541,0x4288D4D0,0x6B406022,0xC528F1B3, + 0xADE5A817,0x038D3986,0x2A458D74,0x842D1CE5,0x79D4E490,0xD7BC7501,0xFE74C1F3,0x501C5062, + 0xDEF63758,0x709EA6C9,0x5956123B,0xF73E83AA,0x0AC77BDF,0xA4AFEA4E,0x8D675EBC,0x230FCF2D, + 0x72831B0E,0xDCEB8A9F,0xF5233E6D,0x5B4BAFFC,0xA6B25789,0x08DAC618,0x211272EA,0x8F7AE37B, + 0x01908441,0xAFF815D0,0x8630A122,0x285830B3,0xD5A1C8C6,0x7BC95957,0x5201EDA5,0xFC697C34, + 0x94A42590,0x3ACCB401,0x130400F3,0xBD6C9162,0x40956917,0xEEFDF886,0xC7354C74,0x695DDDE5, + 0xE7B7BADF,0x49DF2B4E,0x60179FBC,0xCE7F0E2D,0x3386F658,0x9DEE67C9,0xB426D33B,0x1A4E42AA, + 0x65BC6073,0xCBD4F1E2,0xE21C4510,0x4C74D481,0xB18D2CF4,0x1FE5BD65,0x362D0997,0x98459806, + 0x16AFFF3C,0xB8C76EAD,0x910FDA5F,0x3F674BCE,0xC29EB3BB,0x6CF6222A,0x453E96D8,0xEB560749, + 0x839B5EED,0x2DF3CF7C,0x043B7B8E,0xAA53EA1F,0x57AA126A,0xF9C283FB,0xD00A3709,0x7E62A698, + 0xF088C1A2,0x5EE05033,0x7728E4C1,0xD9407550,0x24B98D25,0x8AD11CB4,0xA319A846,0x0D7139D7, + } +#endif // CRC32_USE_LOOKUP_TABLE_SLICING_BY_16 +}; +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/file_adapter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/file_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..398b9e97c0bb68bb2fe2e3e223c641b7dd114acb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/file_adapter.h @@ -0,0 +1,41 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include "caffe2/serialize/istream_adapter.h" +#include "caffe2/serialize/read_adapter_interface.h" + +namespace caffe2 { +namespace serialize { + +class TORCH_API FileAdapter final : public ReadAdapterInterface { + public: + C10_DISABLE_COPY_AND_ASSIGN(FileAdapter); + explicit FileAdapter(const std::string& file_name); + size_t size() const override; + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override; + ~FileAdapter() override; + + private: + // An RAII Wrapper for a FILE pointer. Closes on destruction. + struct RAIIFile { + FILE* fp_; + explicit RAIIFile(const std::string& file_name); + ~RAIIFile(); + }; + + RAIIFile file_; + // The size of the opened file in bytes + uint64_t size_; +}; + +} // namespace serialize +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/in_memory_adapter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/in_memory_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..394898e5ed08ec4c62c8868ae12cf846ad7bf22f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/in_memory_adapter.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace caffe2 { +namespace serialize { + +class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { + public: + explicit MemoryReadAdapter(const void* data, off_t size) + : data_(data), size_(size) {} + + size_t size() const override { + return size_; + } + + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override { + (void)what; + memcpy(buf, (int8_t*)(data_) + pos, n); + return n; + } + + private: + const void* data_; + off_t size_; +}; + +} // namespace serialize +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/inline_container.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/inline_container.h new file mode 100644 index 0000000000000000000000000000000000000000..ef3436b6fece5e661fa4977cafb8d8534f2235fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/inline_container.h @@ -0,0 +1,315 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "caffe2/serialize/istream_adapter.h" +#include "caffe2/serialize/read_adapter_interface.h" +#include "caffe2/serialize/versions.h" + +extern "C" { +typedef struct mz_zip_archive mz_zip_archive; +} + +// PyTorch containers are a special zip archive with the following layout +// archive_name.zip contains: +// archive_name/ +// version # a file with a single decimal number written in ascii, +// # used to establish the version of the archive format +// model.json # overall model description, this is a json output of +// # ModelDef from torch.proto +// # the following names are by convention only, model.json will +// # refer to these files by full names +// tensors/ +// 0 # flat storage for tensor data, meta-data about shapes, etc. is +// # in model.json +// 1 +// ... +// # code entries will only exist for modules that have methods attached +// code/ +// archive_name.py # serialized torch script code (python syntax, using +// PythonPrint) archive_name_my_submodule.py # submodules have separate +// files +// +// The PyTorchStreamWriter also ensures additional useful properties for these +// files +// 1. All files are stored uncompressed. +// 2. All files in the archive are aligned to 64 byte boundaries such that +// it is possible to mmap the entire file and get an aligned pointer to +// tensor data. +// 3. We universally write in ZIP64 format for consistency. + +// The PyTorchStreamReader also provides additional properties: +// 1. It can read zip files that are created with common +// zip tools. This means that even though our writer doesn't compress files, +// the reader can still read files that were compressed. +// 2. It provides a getRecordOffset function which returns the offset into the +// raw file where file data lives. If the file was written with +// PyTorchStreamWriter it is guaranteed to be 64 byte aligned. + +// PyTorchReader/Writer handle checking the version number on the archive format +// and ensure that all files are written to a archive_name directory so they +// unzip cleanly. + +// When developing this format we want to pay particular attention to the +// following use cases: +// +// -- Reading -- +// 1) Reading with full random access +// a) Reading with file api's such as fread() +// b) mmaping the file and jumping around the mapped region +// 2) Reading with 1-pass sequential access +// -> A reader will need to build up a data structure of parsed structures +// as it reads +// +// -- Writing -- +// 1) Writing with full random access +// 2) Writing with 1-pass sequential access +// -> We must take care not to require updating values that have already +// been written. We place the variable-length index at the end and do +// not put any index into the header to fulfill this constraint. + +// The model.json, which contains all the metadata information, +// should be written as the last file. One reason is that the size of tensor +// data is usually stable. As long as the shape and type of the tensor do not +// change, the size of the data won't change. On the other sied, the size of the +// serialized model is likely to change, so we store it as the last record, and +// we don't need to move previous records when updating the model data. + +// The zip format is sufficiently flexible to handle the above use-case. +// it puts its central directory at the end of the archive and we write +// model.json as the last file when writing after we have accumulated all +// other information. + +namespace caffe2 { +namespace serialize { + +static constexpr const char* kSerializationIdRecordName = + ".data/serialization_id"; + +struct MzZipReaderIterWrapper; + +class TORCH_API ChunkRecordIterator { + public: + ~ChunkRecordIterator(); + + // Read at most `chunkSize` into `buf`. Return the number of actual bytes + // read. + size_t next(void* buf); + size_t recordSize() const { + return recordSize_; + } + + private: + ChunkRecordIterator( + size_t recordSize, + size_t chunkSize, + std::unique_ptr iter); + + const size_t recordSize_; + const size_t chunkSize_; + size_t offset_; + std::unique_ptr iter_; + + friend class PyTorchStreamReader; +}; + +class TORCH_API PyTorchStreamReader final { + public: + explicit PyTorchStreamReader(const std::string& file_name); + explicit PyTorchStreamReader(std::istream* in); + explicit PyTorchStreamReader(std::shared_ptr in); + + // return dataptr, size + // set allocator to override default cpu allocator + std::tuple getRecord( + const std::string& name, + std::optional allocator = std::nullopt); + // multi-thread getRecord + std::tuple getRecord( + const std::string& name, + std::vector>& additionalReaders, + std::optional allocator = std::nullopt); + // inplace memory writing + size_t getRecord(const std::string& name, void* dst, size_t n); + // inplace memory writing, multi-threads. + // When additionalReaders is empty, the default behavior is call + // getRecord(name, dst, n) with default reader This approach can be used for + // reading large tensors. + size_t getRecord( + const std::string& name, + void* dst, + size_t n, + std::vector>& additionalReaders); + size_t getRecord( + const std::string& name, + void* dst, + size_t n, + size_t chunk_size, + void* buf, + const std::function& memcpy_func = + nullptr); + + // Concurrent reading records with multiple readers. + // additionalReaders are additional clients to access the underlying record at + // different offsets and write to different trunks of buffers. If the overall + // size of the tensor is 10, and size of additionalReader is 2. The default + // thread will read [0,4), the additional reader will read [4,8). The default + // reader will read [8,10). The default reader will write to buffer[0,4), the + // additional reader will write to buffer[4,8), the additional reader will + // write to buffer[8,10). When additionalReaders is empty, the default + // behavior is call getRecord(name) with default reader This approach can be + // used for reading large tensors. + size_t getRecordMultiReaders( + const std::string& name, + std::vector>& additionalReaders, + void* dst, + size_t n); + + size_t getRecordSize(const std::string& name); + size_t getRecordHeaderOffset(const std::string& name); + size_t getRecordOffset(const std::string& name); + size_t getRecordOffsetNoRead( + size_t cursor, + std::string filename, + size_t size, + uint64_t alignment); + bool hasRecord(const std::string& name); + std::vector getAllRecords(); + + ChunkRecordIterator createChunkReaderIter( + const std::string& name, + const size_t recordSize, + const size_t chunkSize); + + ~PyTorchStreamReader(); + uint64_t version() const { + return version_; + } + const std::string& serializationId() { + return serialization_id_; + } + + void setShouldLoadDebugSymbol(bool should_load_debug_symbol) { + load_debug_symbol_ = should_load_debug_symbol; + } + void setAdditionalReaderSizeThreshold(const size_t& size) { + additional_reader_size_threshold_ = size; + } + + private: + void init(); + size_t read(uint64_t pos, char* buf, size_t n); + void valid(const char* what, const char* info = ""); + size_t getRecordID(const std::string& name); + + friend size_t + istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n); + std::unique_ptr ar_; + std::string archive_name_; + std::string archive_name_plus_slash_; + std::shared_ptr in_; + int64_t version_; + std::mutex reader_lock_; + bool load_debug_symbol_ = true; + std::string serialization_id_; + size_t additional_reader_size_threshold_; +}; + +class TORCH_API PyTorchStreamWriter final { + public: + explicit PyTorchStreamWriter( + const std::string& archive_name, + bool compute_crc32 = true, + uint64_t alignment = 64); + explicit PyTorchStreamWriter( + const std::function writer_func, + bool compute_crc32 = true, + uint64_t alignment = 64); + + void setMinVersion(const uint64_t version); + + void writeRecord( + const std::string& name, + const void* data, + size_t size, + bool compress = false); + void writeEndOfFile(); + + const std::unordered_set& getAllWrittenRecords(); + + bool finalized() const { + return finalized_; + } + + const std::string& archiveName() { + return archive_name_; + } + + const std::string& serializationId() { + return serialization_id_; + } + + ~PyTorchStreamWriter(); + + private: + void setup(const std::string& file_name); + void valid(const char* what, const char* info = ""); + void writeSerializationId(); + size_t current_pos_ = 0; + std::unordered_set files_written_; + std::unique_ptr ar_; + std::string archive_name_; + std::string archive_name_plus_slash_; + std::string padding_; + std::ofstream file_stream_; + std::function writer_func_; + uint64_t combined_uncomp_crc32_ = 0; + std::string serialization_id_; + bool compute_crc32_; + uint64_t alignment_; + + // This number will be updated when the model has operators + // that have valid upgraders. + uint64_t version_ = kMinProducedFileFormatVersion; + bool finalized_ = false; + bool err_seen_ = false; + friend size_t ostream_write_func( + void* pOpaque, + uint64_t file_ofs, + const void* pBuf, + size_t n); +}; + +namespace detail { + +// Returns a record to be appended to the local user extra data entry in order +// to make data beginning aligned at kFieldAlignment bytes boundary. +size_t getPadding( + size_t cursor, + size_t filename_size, + size_t size, + std::string& padding_buf, + uint64_t alignment); + +std::tuple +getOffset(size_t cursor, size_t filename_size, size_t size, uint64_t alignment); + +} // namespace detail + +} // namespace serialize +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/istream_adapter.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/istream_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..0e205be7f1ceef1ffc92f686d1cd464f60899ae3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/istream_adapter.h @@ -0,0 +1,32 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include "c10/macros/Macros.h" +#include "caffe2/serialize/read_adapter_interface.h" + +namespace caffe2 { +namespace serialize { + +// this is a reader implemented by std::istream +class TORCH_API IStreamAdapter final : public ReadAdapterInterface { + public: + C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter); + explicit IStreamAdapter(std::istream* istream); + size_t size() const override; + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override; + ~IStreamAdapter() override; + + private: + std::istream* istream_; + void validate(const char* what) const; +}; + +} // namespace serialize +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/read_adapter_interface.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/read_adapter_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..bc4b4505f4b786a0c8088e7ecc2253b877a20298 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/read_adapter_interface.h @@ -0,0 +1,28 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include "c10/macros/Macros.h" + +namespace caffe2 { +namespace serialize { + +// this is the interface for the (file/stream/memory) reader in +// PyTorchStreamReader. with this interface, we can extend the support +// besides standard istream +class TORCH_API ReadAdapterInterface { + public: + virtual size_t size() const = 0; + virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const = 0; + virtual ~ReadAdapterInterface(); +}; + +} // namespace serialize +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/versions.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/versions.h new file mode 100644 index 0000000000000000000000000000000000000000..f21f4db27caa05bf69b3f05fdcf93ccf241d0944 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/serialize/versions.h @@ -0,0 +1,138 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace caffe2 { +namespace serialize { + +constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; + +constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL; + +// Versions (i.e. why was the version number bumped?) + +// Note [Dynamic Versions and torch.jit.save vs. torch.save] +// +// Our versioning scheme has a "produced file format version" which +// describes how an archive is to be read. The version written in an archive +// is at least this current produced file format version, but may be greater +// if it includes certain symbols. We refer to these conditional versions +// as "dynamic," since they are identified at runtime. +// +// Dynamic versioning is useful when an operator's semantics are updated. +// When using torch.jit.save we want those semantics to be preserved. If +// we bumped the produced file format version on every change, however, +// then older versions of PyTorch couldn't read even simple archives, like +// a single tensor, from newer versions of PyTorch. Instead, we +// assign dynamic versions to these changes that override the +// produced file format version as needed. That is, when the semantics +// of torch.div changed it was assigned dynamic version 4, and when +// torch.jit.saving modules that use torch.div those archives also have +// (at least) version 4. This prevents earlier versions of PyTorch +// from accidentally performing the wrong kind of division. Modules +// that don't use torch.div or other operators with dynamic versions +// can write the produced file format version, and these programs will +// run as expected on earlier versions of PyTorch. +// +// While torch.jit.save attempts to preserve operator semantics, +// torch.save does not. torch.save is analogous to pickling Python, so +// a function that uses torch.div will have different behavior if torch.saved +// and torch.loaded across PyTorch versions. From a technical perspective, +// torch.save ignores dynamic versioning. + +// 1. Initial version +// 2. Removed op_version_set version numbers +// 3. Added type tags to pickle serialization of container types +// 4. (Dynamic) Stopped integer division using torch.div +// (a versioned symbol preserves the historic behavior of versions 1--3) +// 5. (Dynamic) Stops torch.full inferring a floating point dtype +// when given bool or integer fill values. +// 6. Write version string to `./data/version` instead of `version`. + +// [12/15/2021] +// kProducedFileFormatVersion is set to 7 from 3 due to a different +// interpretation of what file format version is. +// Whenever there is new upgrader introduced, +// this number should be bumped. +// The reasons that version is bumped in the past: +// 1. aten::div is changed at version 4 +// 2. aten::full is changed at version 5 +// 3. torch.package uses version 6 +// 4. Introduce new upgrader design and set the version number to 7 +// mark this change +// -------------------------------------------------- +// We describe new operator version bump reasons here: +// 1) [01/24/2022] +// We bump the version number to 8 to update aten::linspace +// and aten::linspace.out to error out when steps is not +// provided. (see: https://github.com/pytorch/pytorch/issues/55951) +// 2) [01/30/2022] +// Bump the version number to 9 to update aten::logspace and +// and aten::logspace.out to error out when steps is not +// provided. (see: https://github.com/pytorch/pytorch/issues/55951) +// 3) [02/11/2022] +// Bump the version number to 10 to update aten::gelu and +// and aten::gelu.out to support the new approximate kwarg. +// (see: https://github.com/pytorch/pytorch/pull/61439) +constexpr uint64_t kProducedFileFormatVersion = 0xAL; + +// Absolute minimum version we will write packages. This +// means that every package from now on will always be +// greater than this number. +constexpr uint64_t kMinProducedFileFormatVersion = 0x3L; + +// The version we write when the archive contains bytecode. +// It must be higher or eq to kProducedFileFormatVersion. +// Because torchscript changes is likely introduce bytecode change. +// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion +// should be increased too. The relationship is: +// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion +// >= kProducedFileFormatVersion +// If a format change is forward compatible (still readable by older +// executables), we will not increment the version number, to minimize the +// risk of breaking existing clients. TODO: A better way would be to allow +// the caller that creates a model to specify a maximum version that its +// clients can accept. +// Versions: +// 0x1L: Initial version +// 0x2L: (Comment missing) +// 0x3L: (Comment missing) +// 0x4L: (update) Added schema to function tuple. Forward-compatible change. +// 0x5L: (update) Update bytecode is sharing constant tensor files from +// torchscript, and only serialize extra tensors that are not in the +// torchscript constant table. Also update tensor storage schema adapting to +// the unify format, the root key of tensor storage is updated from {index} to +// {the_pointer_value_the_tensor.storage}, for example: +// `140245072983168.storage` Forward-compatibility change. +// 0x6L: Implicit opereator versioning using number of specified argument. +// Refer to the summary of https://github.com/pytorch/pytorch/pull/56845 for +// details. +// 0x7L: Enable support for operators with default arguments plus out +// arguments. Refer. See https://github.com/pytorch/pytorch/pull/63651 for +// details. +// 0x8L: Emit promoted operators as instructions. See +// https://github.com/pytorch/pytorch/pull/71662 for details. +// 0x9L: Change serialization format from pickle to format This version is to +// serve migration. v8 pickle and v9 flatbuffer are the same. Refer to the +// summary of https://github.com/pytorch/pytorch/pull/75201 for more details. +constexpr uint64_t kProducedBytecodeVersion = 0x8L; + +// static_assert( +// kProducedBytecodeVersion >= kProducedFileFormatVersion, +// "kProducedBytecodeVersion must be higher or equal to +// kProducedFileFormatVersion."); + +// Introduce kMinSupportedBytecodeVersion and kMaxSupportedBytecodeVersion +// for limited backward/forward compatibility support of bytecode. If +// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion +// (in loader), we should support this model_version. For example, we provide a +// wrapper to handle an updated operator. +constexpr uint64_t kMinSupportedBytecodeVersion = 0x4L; +constexpr uint64_t kMaxSupportedBytecodeVersion = 0x9L; + +} // namespace serialize +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/fixed_divisor.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/fixed_divisor.h new file mode 100644 index 0000000000000000000000000000000000000000..8041a2723c8603b05b26956126e37eff436ac905 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/fixed_divisor.h @@ -0,0 +1,137 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef CAFFE2_UTILS_FIXED_DIVISOR_H_ +#define CAFFE2_UTILS_FIXED_DIVISOR_H_ + +#include +#include +#include + +// See Note [hip-clang differences to hcc] + +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) || defined(__HIP__) || \ + (defined(__clang__) && defined(__CUDA__)) +#define FIXED_DIVISOR_DECL inline __host__ __device__ +#else +#define FIXED_DIVISOR_DECL inline +#endif + +namespace caffe2 { + +// Utility class for quickly calculating quotients and remainders for +// a known integer divisor +template +class FixedDivisor {}; + +// Works for any positive divisor, 1 to INT_MAX. One 64-bit +// multiplication and one 64-bit shift is used to calculate the +// result. +template <> +class FixedDivisor { + public: + FixedDivisor() = default; + + explicit FixedDivisor(const std::int32_t d) : d_(d) { +#if !defined(USE_ROCM) + CalcSignedMagic(); +#endif // USE_ROCM + } + + FIXED_DIVISOR_DECL std::int32_t d() const { + return d_; + } + +#if !defined(USE_ROCM) + FIXED_DIVISOR_DECL std::uint64_t magic() const { + return magic_; + } + + FIXED_DIVISOR_DECL int shift() const { + return shift_; + } +#endif // USE_ROCM + + /// Calculates `q = n / d`. + FIXED_DIVISOR_DECL std::int32_t Div(const std::int32_t n) const { +#if defined(USE_ROCM) + return n / d_; +#else // USE_ROCM + // In lieu of a mulhi instruction being available, perform the + // work in uint64 + return (int32_t)((magic_ * (uint64_t)n) >> shift_); +#endif // USE_ROCM + } + + /// Calculates `r = n % d`. + FIXED_DIVISOR_DECL std::int32_t Mod(const std::int32_t n) const { + return n - d_ * Div(n); + } + + /// Calculates `q = n / d` and `r = n % d` together. + FIXED_DIVISOR_DECL void + DivMod(const std::int32_t n, std::int32_t* q, int32_t* r) const { + *q = Div(n); + *r = n - d_ * *q; + } + + private: +#if !defined(USE_ROCM) + // Calculates magic multiplicative value and shift amount for calculating `q = + // n / d` for signed 32-bit integers. + // Implementation taken from Hacker's Delight section 10. + void CalcSignedMagic() { + if (d_ == 1) { + magic_ = UINT64_C(0x1) << 32; + shift_ = 32; + return; + } + + const std::uint32_t two31 = UINT32_C(0x80000000); + const std::uint32_t ad = std::abs(d_); + const std::uint32_t t = two31 + ((uint32_t)d_ >> 31); + const std::uint32_t anc = t - 1 - t % ad; // Absolute value of nc. + std::uint32_t p = 31; // Init. p. + std::uint32_t q1 = two31 / anc; // Init. q1 = 2**p/|nc|. + std::uint32_t r1 = two31 - q1 * anc; // Init. r1 = rem(2**p, |nc|). + std::uint32_t q2 = two31 / ad; // Init. q2 = 2**p/|d|. + std::uint32_t r2 = two31 - q2 * ad; // Init. r2 = rem(2**p, |d|). + std::uint32_t delta = 0; + do { + ++p; + q1 <<= 1; // Update q1 = 2**p/|nc|. + r1 <<= 1; // Update r1 = rem(2**p, |nc|). + if (r1 >= anc) { // (Must be an unsigned + ++q1; // comparison here). + r1 -= anc; + } + q2 <<= 1; // Update q2 = 2**p/|d|. + r2 <<= 1; // Update r2 = rem(2**p, |d|). + if (r2 >= ad) { // (Must be an unsigned + ++q2; // comparison here). + r2 -= ad; + } + delta = ad - r2; + } while (q1 < delta || (q1 == delta && r1 == 0)); + std::int32_t magic = q2 + 1; + if (d_ < 0) { + magic = -magic; + } + shift_ = p; + magic_ = (std::uint64_t)(std::uint32_t)magic; + } +#endif // USE_ROCM + + std::int32_t d_ = 1; + +#if !defined(USE_ROCM) + std::uint64_t magic_; + int shift_; +#endif // USE_ROCM +}; + +} // namespace caffe2 + +#endif // CAFFE2_UTILS_FIXED_DIVISOR_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/proto_wrap.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/proto_wrap.h new file mode 100644 index 0000000000000000000000000000000000000000..29b58072e159b1ca826fc4b6d8631e7590943969 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/proto_wrap.h @@ -0,0 +1,42 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef CAFFE2_UTILS_PROTO_WRAP_H_ +#define CAFFE2_UTILS_PROTO_WRAP_H_ + +#include + +namespace caffe2 { + +// A wrapper function to shut down protobuf library (this is needed in ASAN +// testing and valgrind cases to avoid protobuf appearing to "leak" memory). +TORCH_API void ShutdownProtobufLibrary(); + +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); +} // namespace caffe2 + +namespace ONNX_NAMESPACE { + +// ONNX wrapper functions for protobuf's GetEmptyStringAlreadyInited() function +// used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +} // namespace ONNX_NAMESPACE + +namespace torch { + +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +void ShutdownProtobufLibrary(); + +} // namespace torch +#endif // CAFFE2_UTILS_PROTO_WRAP_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/string_utils.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/string_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f8d2d49efdb0ca402a6e6b60c0c6de7db9249684 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/string_utils.h @@ -0,0 +1,56 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include + +namespace caffe2 { + +TORCH_API std::vector +split(char separator, const std::string& string, bool ignore_empty = false); + +TORCH_API std::string trim(const std::string& str); + +TORCH_API size_t editDistance( + const std::string& s1, + const std::string& s2, + size_t max_distance = 0); + +TORCH_API inline bool StartsWith( + const std::string& str, + const std::string& prefix) { + return str.length() >= prefix.length() && + std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == + prefix.end(); +} + +TORCH_API inline bool EndsWith( + const std::string& full, + const std::string& ending) { + if (full.length() >= ending.length()) { + return ( + 0 == + full.compare(full.length() - ending.length(), ending.length(), ending)); + } else { + return false; + } +} + +TORCH_API int32_t editDistanceHelper( + const char* s1, + size_t s1_len, + const char* s2, + size_t s2_len, + std::vector& current, + std::vector& previous, + std::vector& previous1, + size_t max_distance); +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPool.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPool.h new file mode 100644 index 0000000000000000000000000000000000000000..a3769ec59ebdc60be60a685779aa4c3903e0f721 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPool.h @@ -0,0 +1,84 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef CAFFE2_UTILS_THREADPOOL_H_ +#define CAFFE2_UTILS_THREADPOOL_H_ + +#include "ThreadPoolCommon.h" + +#include +#include +#include +#include +#include + +#include "c10/util/Flags.h" +#include "caffe2/core/common.h" + +// +// A work-stealing threadpool loosely based off of pthreadpool +// + +namespace caffe2 { + +struct Task; +class WorkersPool; + +constexpr size_t kCacheLineSize = 64; + +// A threadpool with the given number of threads. +// NOTE: the kCacheLineSize alignment is present only for cache +// performance, and is not strictly enforced (for example, when +// the object is created on the heap). Thus, in order to avoid +// misaligned intrinsics, no SSE instructions shall be involved in +// the ThreadPool implementation. +// Note: alignas is disabled because some compilers do not deal with +// TORCH_API and alignas annotations at the same time. +class TORCH_API /*alignas(kCacheLineSize)*/ ThreadPool { + public: + static ThreadPool* createThreadPool(int numThreads); + static std::unique_ptr defaultThreadPool(); + virtual ~ThreadPool() = default; + // Returns the number of threads currently in use + virtual int getNumThreads() const = 0; + virtual void setNumThreads(size_t numThreads) = 0; + + // Sets the minimum work size (range) for which to invoke the + // threadpool; work sizes smaller than this will just be run on the + // main (calling) thread + void setMinWorkSize(size_t size) { + std::lock_guard guard(executionMutex_); + minWorkSize_ = size; + } + + size_t getMinWorkSize() const { + return minWorkSize_; + } + virtual void run(const std::function& fn, size_t range) = 0; + + // Run an arbitrary function in a thread-safe manner accessing the Workers + // Pool + virtual void withPool(const std::function& fn) = 0; + + protected: + static size_t defaultNumThreads_; + mutable std::mutex executionMutex_; + size_t minWorkSize_; +}; + +size_t getDefaultNumThreads(); +} // namespace caffe2 + +C10_DECLARE_bool(caffe2_threadpool_force_inline); + +// Whether or not threadpool caps apply to Android +C10_DECLARE_int(caffe2_threadpool_android_cap); + +// Whether or not threadpool caps apply to iOS and MacOS +C10_DECLARE_int(caffe2_threadpool_ios_cap); +C10_DECLARE_int(caffe2_threadpool_macos_cap); + +C10_DECLARE_int(pthreadpool_size); +#endif // CAFFE2_UTILS_THREADPOOL_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPoolCommon.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPoolCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..0bd04aa595c383ea8c1e0cb833e81e5478bc879b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/ThreadPoolCommon.h @@ -0,0 +1,25 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef CAFFE2_UTILS_THREADPOOL_COMMON_H_ +#define CAFFE2_UTILS_THREADPOOL_COMMON_H_ + +#ifdef __APPLE__ +#include +#endif + +// caffe2 depends upon NNPACK, which depends upon this threadpool, so +// unfortunately we can't reference core/common.h here + +// This is copied from core/common.h's definition of C10_MOBILE +// Define enabled when building for iOS or Android devices +#if defined(__ANDROID__) +#define C10_ANDROID 1 +#elif (defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define C10_IOS 1 +#endif // ANDROID / IOS + +#endif // CAFFE2_UTILS_THREADPOOL_COMMON_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/WorkersPool.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/WorkersPool.h new file mode 100644 index 0000000000000000000000000000000000000000..a4adbac9b3c1b3a9672b511cb24dda1c48a4622e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/WorkersPool.h @@ -0,0 +1,383 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include "c10/util/thread_name.h" +#include +#include + +#if defined(_MSC_VER) +#include +#endif + +namespace caffe2 { + +// Uses code derived from gemmlowp, +// https://github.com/google/gemmlowp/blob/6c91e1ed0c2eff1182d804310b92911fe9c18019/internal/multi_thread_gemm.h +// Changes: +// - allocation-free execute() +// - Use RAII where possible. +// - Run the first task on the main thread (since that is the largest task). +// - removed custom allocator. +// - Removed some ifdef's +// - cache-line align Worker. +// - use std::atomic instead of volatile and custom barriers. +// - use std::mutex/std::condition_variable instead of raw pthreads. + +constexpr size_t kGEMMLOWPCacheLineSize = 64; + +template +struct AllocAligned { + // Allocate a T aligned at an `align` byte address + template + static T* alloc(Args&&... args) { + void* p = nullptr; + +#if defined(__ANDROID__) + p = memalign(kGEMMLOWPCacheLineSize, sizeof(T)); +#elif defined(_MSC_VER) + p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize); +#else + auto res = posix_memalign(&p, kGEMMLOWPCacheLineSize, sizeof(T)); + (void)res; +#endif + + if (p) { + return new (p) T(std::forward(args)...); + } + + return nullptr; + } + + // Free a T previously allocated via AllocAligned::alloc() + static void release(T* p) { + if (p) { + p->~T(); +#if defined(_MSC_VER) + _aligned_free((void*)p); +#else + free((void*)p); +#endif + } + } +}; + +// Deleter object for unique_ptr for an aligned object +template +struct AlignedDeleter { + void operator()(T* p) const { AllocAligned::release(p); } +}; + +// make_unique that guarantees alignment +template +struct MakeAligned { + template + static std::unique_ptr> make(Args&&... args) { + return std::unique_ptr>( + AllocAligned::alloc(std::forward(args)...)); + } +}; + +const int kMaxBusyWaitNOPs = 32 * 1000 * 1000; + +#if defined(_MSC_VER) +#define GEMMLOWP_NOP __nop(); +#else +#define GEMMLOWP_NOP "nop\n" +#endif + +#define GEMMLOWP_STRING_CONCAT_4(X) X X X X +#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP) +#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4) +#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16) + +inline int Do256NOPs() { +#if defined(_MSC_VER) + GEMMLOWP_NOP64; +#else + asm volatile(GEMMLOWP_NOP64); +#endif + return 64; +} + +#undef GEMMLOWP_STRING_CONCAT_4 +#undef GEMMLOWP_NOP256 +#undef GEMMLOWP_NOP64 +#undef GEMMLOWP_NOP16 +#undef GEMMLOWP_NOP4 +#undef GEMMLOWP_NOP + +// Waits until *var != initial_value. +// +// Returns the new value of *var. The guarantee here is that +// the return value is different from initial_value, and that that +// new value has been taken by *var at some point during the +// execution of this function. There is no guarantee that this is +// still the value of *var when this function returns, since *var is +// not assumed to be guarded by any lock. +// +// First does some busy-waiting for a fixed number of no-op cycles, +// then falls back to passive waiting for the given condvar, guarded +// by the given mutex. +// +// The idea of doing some initial busy-waiting is to help get +// better and more consistent multithreading benefits for small GEMM sizes. +// Busy-waiting help ensuring that if we need to wake up soon after having +// started waiting, then we can wake up quickly (as opposed to, say, +// having to wait to be scheduled again by the OS). On the other hand, +// we must still eventually revert to passive waiting for longer waits +// (e.g. worker threads having finished a GEMM and waiting until the next GEMM) +// so as to avoid permanently spinning. +// +template +T WaitForVariableChange(std::atomic* var, + T initial_value, + std::condition_variable* cond, + std::mutex* mutex) { + // If we are on a platform that supports it, spin for some time. + { + int nops = 0; + // First, trivial case where the variable already changed value. + T new_value = var->load(std::memory_order_relaxed); + if (new_value != initial_value) { + std::atomic_thread_fence(std::memory_order_acquire); + return new_value; + } + // Then try busy-waiting. + while (nops < kMaxBusyWaitNOPs) { + nops += Do256NOPs(); + new_value = var->load(std::memory_order_relaxed); + if (new_value != initial_value) { + std::atomic_thread_fence(std::memory_order_acquire); + return new_value; + } + } + } + + // Finally, do real passive waiting. + { + std::unique_lock g(*mutex); + T new_value = var->load(std::memory_order_relaxed); + // Handle spurious wakeups. + cond->wait(g, [&]() { + new_value = var->load(std::memory_order_relaxed); + return new_value != initial_value; + }); + TORCH_DCHECK_NE(static_cast(new_value), static_cast(initial_value)); + return new_value; + } +} + +// A BlockingCounter lets one thread to wait for N events to occur. +// This is how the master thread waits for all the worker threads +// to have finished working. +class BlockingCounter { + public: + // Sets/resets the counter; initial_count is the number of + // decrementing events that the Wait() call will be waiting for. + void Reset(std::size_t initial_count) { + std::lock_guard g(mutex_); + TORCH_DCHECK_EQ(count_, 0); + count_ = initial_count; + } + + // Decrements the counter; if the counter hits zero, signals + // the thread that was waiting for that, and returns true. + // Otherwise (if the decremented count is still nonzero), + // returns false. + bool DecrementCount() { + const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1; + if (count_value == 0) { + std::lock_guard g(mutex_); + cond_.notify_one(); + } + bool retval = count_value == 0; + return retval; + } + + // Waits for the N other threads (N having been set by Reset()) + // to hit the BlockingCounter. + void Wait() { + while (size_t count_value = count_.load(std::memory_order_relaxed)) { + WaitForVariableChange(&count_, count_value, &cond_, &mutex_); + } + } + + private: + std::condition_variable cond_; + std::mutex mutex_; + std::atomic count_{0}; +}; + +// A workload for a worker. +struct Task { + Task() = default; + virtual ~Task() = default; + virtual void Run() = 0; +}; + +// A worker thread. +class alignas(kGEMMLOWPCacheLineSize) Worker { + public: + enum class State : uint8_t { + ThreadStartup, // The initial state before the thread main loop runs. + Ready, // Is not working, has not yet received new work to do. + HasWork, // Has work to do. + ExitAsSoonAsPossible // Should exit at earliest convenience. + }; + + explicit Worker(BlockingCounter* counter_to_decrement_when_ready) + : task_(nullptr), + state_(State::ThreadStartup), + counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { + thread_ = std::make_unique([this]() { + c10::setThreadName("pt_thread_pool"); + this->ThreadFunc(); + }); + } + + ~Worker() { + ChangeState(State::ExitAsSoonAsPossible); + thread_->join(); + } + + // Changes State; may be called from either the worker thread + // or the master thread; however, not all state transitions are legal, + // which is guarded by assertions. + void ChangeState(State new_state) { + std::lock_guard g(state_mutex_); + DCHECK(new_state != state_.load(std::memory_order_relaxed)); + switch (state_.load(std::memory_order_relaxed)) { + case State::ThreadStartup: + DCHECK(new_state == State::Ready); + break; + case State::Ready: + DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible); + break; + case State::HasWork: + DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible); + break; + case State::ExitAsSoonAsPossible: + default: + abort(); + } + state_.store(new_state, std::memory_order_relaxed); + state_cond_.notify_one(); + if (new_state == State::Ready) { + counter_to_decrement_when_ready_->DecrementCount(); + } + } + + // Thread entry point. + void ThreadFunc() { + c10::setThreadName("CaffeWorkersPool"); + ChangeState(State::Ready); + + // Thread main loop + while (true) { + // Get a state to act on + // In the 'Ready' state, we have nothing to do but to wait until + // we switch to another state. + State state_to_act_upon = + WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_); + + // We now have a state to act on, so act. + switch (state_to_act_upon) { + case State::HasWork: + // Got work to do! So do it, and then revert to 'Ready' state. + DCHECK(task_.load()); + (*task_).Run(); + task_ = nullptr; + ChangeState(State::Ready); + break; + case State::ExitAsSoonAsPossible: + return; + case State::Ready: + case State::ThreadStartup: + default: + abort(); + } + } + } + + static void* ThreadFunc(void* arg) { + static_cast(arg)->ThreadFunc(); + return nullptr; + } + + // Called by the master thread to give this worker work to do. + // It is only legal to call this if the worker + void StartWork(Task* task) { + DCHECK(!task_.load()); + task_ = task; + DCHECK(state_.load(std::memory_order_acquire) == State::Ready); + ChangeState(State::HasWork); + } + + private: + // The underlying thread. + std::unique_ptr thread_; + + // The task to be worked on. + std::atomic task_; + + // The condition variable and mutex guarding state changes. + std::condition_variable state_cond_; + std::mutex state_mutex_; + + // The state enum tells if we're currently working, waiting for work, etc. + std::atomic state_; + + // pointer to the master's thread BlockingCounter object, to notify the + // master thread of when this worker switches to the 'Ready' state. + BlockingCounter* const counter_to_decrement_when_ready_; +}; + +class WorkersPool { + public: + WorkersPool() = default; + + void Execute(const std::vector>& tasks) { + CAFFE_ENFORCE_GE(tasks.size(), 1); + // One of the tasks will be run on the current thread. + int workers_count = tasks.size() - 1; + CreateWorkers(workers_count); + TORCH_DCHECK_LE(workers_count, (int)workers_.size()); + counter_to_decrement_when_ready_.Reset(workers_count); + for (const auto task : c10::irange(1, tasks.size())) { + workers_[task - 1]->StartWork(tasks[task].get()); + } + // Execute the remaining workload immediately on the current thread. + auto& task = tasks.front(); + task->Run(); + // Wait for the workers submitted above to finish. + counter_to_decrement_when_ready_.Wait(); + } + + private: + // Ensures that the pool has at least the given count of workers. + // If any new worker has to be created, this function waits for it to + // be ready. + void CreateWorkers(std::size_t workers_count) { + if (workers_.size() >= workers_count) { + return; + } + counter_to_decrement_when_ready_.Reset(workers_count - workers_.size()); + while (workers_.size() < workers_count) { + workers_.push_back(MakeAligned::make(&counter_to_decrement_when_ready_)); + } + counter_to_decrement_when_ready_.Wait(); + } + + C10_DISABLE_COPY_AND_ASSIGN(WorkersPool); + std::vector>> workers_; + // The BlockingCounter used to wait for the workers. + BlockingCounter counter_to_decrement_when_ready_; +}; +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool-cpp.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool-cpp.h new file mode 100644 index 0000000000000000000000000000000000000000..cb9a01d3bd2ec1bc12d5290b965f18d9bb0cbfb4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool-cpp.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#ifdef USE_PTHREADPOOL + +#ifdef USE_INTERNAL_PTHREADPOOL_IMPL +#include +#else +#include +#endif + +#include +#include +#include + +namespace caffe2 { + +class PThreadPool final { + public: + explicit PThreadPool(size_t thread_count); + ~PThreadPool() = default; + + PThreadPool(const PThreadPool&) = delete; + PThreadPool& operator=(const PThreadPool&) = delete; + + PThreadPool(PThreadPool&&) = delete; + PThreadPool& operator=(PThreadPool&&) = delete; + + size_t get_thread_count() const; + void set_thread_count(size_t thread_count); + + // Run, in parallel, function fn(task_id) over task_id in range [0, range). + // This function is blocking. All input is processed by the time it returns. + void run(const std::function& fn, size_t range); + + private: + friend pthreadpool_t pthreadpool_(); + + private: + mutable std::mutex mutex_; + std::unique_ptr threadpool_; +}; + +// Return a singleton instance of PThreadPool for ATen/TH multithreading. +PThreadPool* pthreadpool(); +PThreadPool* pthreadpool(size_t thread_count); + +// Exposes the underlying implementation of PThreadPool. +// Only for use in external libraries so as to unify threading across +// internal (i.e. ATen, etc.) and external (e.g. NNPACK, QNNPACK, XNNPACK) +// use cases. +pthreadpool_t pthreadpool_(); + +} // namespace caffe2 + +#endif /* USE_PTHREADPOOL */ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool.h new file mode 100644 index 0000000000000000000000000000000000000000..ff7ff896b589dff1f51abd155e685fe2ee231750 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/pthreadpool.h @@ -0,0 +1,198 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// pthreadpool header from https://github.com/Maratyszcza/pthreadpool +// for NNPACK +#ifndef CAFFE2_UTILS_PTHREADPOOL_H_ +#define CAFFE2_UTILS_PTHREADPOOL_H_ + +#include "ThreadPoolCommon.h" + +#include // for size_t +#include // for uint32_t + +#if defined(USE_PTHREADPOOL) +// This is a hack. +// Mainly introduced here because +// 1. NNPACK can be compiled to use internal legacy threadpool implementation because much of C2 depends on that. +// 2. Then if we want to use NNPACK in PyTorch, which uses new pthreadpool, then we will supply new pthreadpool pointer +// to NNPACK. This will not work if NNPACK is compiled with internal legacy threadpool. Thus this guard +// along with changes in pthreadpool_impl.cc allows us to override that behavior. +// It enables us to use NNPACK from pytorch using `caffe2::pthreadpool_()` +namespace caffe2 { +class WithCastToNewThreadPool { + public: + explicit WithCastToNewThreadPool(bool use_new_threadpool); + ~WithCastToNewThreadPool(); + private: + bool use_new_threadpool_; +}; +} +#endif + +typedef struct pthreadpool* legacy_pthreadpool_t; + +typedef void (*legacy_pthreadpool_function_1d_t)(void*, size_t); +typedef void (*legacy_pthreadpool_function_1d_tiled_t)(void*, size_t, size_t); +typedef void (*legacy_pthreadpool_function_2d_t)(void*, size_t, size_t); +typedef void (*legacy_pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*legacy_pthreadpool_function_3d_tiled_t)( + void*, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t); +typedef void (*legacy_pthreadpool_function_4d_tiled_t)( + void*, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t); + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Creates a thread pool with the specified number of threads. + * + * @param[in] threads_count The number of threads in the thread pool. + * A value of 0 has special interpretation: it creates a thread for each + * processor core available in the system. + * + * @returns A pointer to an opaque thread pool object. + * On error the function returns NULL and sets errno accordingly. + */ + +// Returns internal threadpool impl. +legacy_pthreadpool_t legacy_pthreadpool_create(size_t threads_count); + +/** + * Queries the number of threads in a thread pool. + * + * @param[in] threadpool The thread pool to query. + * + * @returns The number of threads in the thread pool. + */ +size_t legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool); + +/** + * Processes items in parallel using threads from a thread pool. + * + * When the call returns, all items have been processed and the thread pool is + * ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param[in] threadpool The thread pool to use for parallelisation. + * @param[in] function The function to call for each item. + * @param[in] argument The first argument passed to the @a function. + * @param[in] items The number of items to process. The @a function + * will be called once for each item. + */ +void legacy_pthreadpool_compute_1d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_t function, + void* argument, + size_t range); + +void legacy_pthreadpool_parallelize_1d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_t function, + void* argument, + size_t range, + uint32_t flags); + +void legacy_pthreadpool_compute_1d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_1d_tiled_t function, + void* argument, + size_t range, + size_t tile); + +void legacy_pthreadpool_compute_2d( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_t function, + void* argument, + size_t range_i, + size_t range_j); + +void legacy_pthreadpool_compute_2d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_2d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j); + +void legacy_pthreadpool_compute_3d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_3d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_i, + size_t tile_j, + size_t tile_k); + +void legacy_pthreadpool_compute_4d_tiled( + legacy_pthreadpool_t threadpool, + legacy_pthreadpool_function_4d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_i, + size_t tile_j, + size_t tile_k, + size_t tile_l); + +/** + * Terminates threads in the thread pool and releases associated resources. + * + * @warning Accessing the thread pool after a call to this function constitutes + * undefined behaviour and may cause data corruption. + * + * @param[in,out] threadpool The thread pool to destroy. + */ +void legacy_pthreadpool_destroy(legacy_pthreadpool_t threadpool); + +#ifdef USE_INTERNAL_PTHREADPOOL_IMPL + +#define pthreadpool_t legacy_pthreadpool_t +#define pthreadpool_function_1d_t legacy_pthreadpool_function_1d_t +#define pthreadpool_function_1d_tiled_t legacy_pthreadpool_function_1d_tiled_t +#define pthreadpool_function_2d_t legacy_pthreadpool_function_2d_t +#define pthreadpool_function_2d_tiled_t legacy_pthreadpool_function_2d_tiled_t +#define pthreadpool_function_3d_tiled_t legacy_pthreadpool_function_3d_tiled_t +#define pthreadpool_function_4d_tiled_t legacy_pthreadpool_function_4d_tiled_t +#define pthreadpool_create legacy_pthreadpool_create +#define pthreadpool_destroy legacy_pthreadpool_destroy +#define pthreadpool_get_threads_count legacy_pthreadpool_get_threads_count +#define pthreadpool_compute_1d legacy_pthreadpool_compute_1d +#define pthreadpool_parallelize_1d legacy_pthreadpool_parallelize_1d +#define pthreadpool_compute_1d_tiled legacy_pthreadpool_compute_1d_tiled +#define pthreadpool_compute_2d legacy_pthreadpool_compute_2d +#define pthreadpool_compute_2d_tiled legacy_pthreadpool_compute_2d_tiled +#define pthreadpool_compute_3d_tiled legacy_pthreadpool_compute_3d_tiled +#define pthreadpool_compute_4d_tiled legacy_pthreadpool_compute_4d_tiled + +#endif /* USE_INTERNAL_PTHREADPOOL_IMPL */ + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif // CAFFE2_UTILS_PTHREADPOOL_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/thread_pool_guard.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/thread_pool_guard.h new file mode 100644 index 0000000000000000000000000000000000000000..cb76646e6f61bdc1540bacd7dcbf88b4aa09b5f4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/caffe2/utils/threadpool/thread_pool_guard.h @@ -0,0 +1,28 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace caffe2 { + +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct TORCH_API _NoPThreadPoolGuard { + static bool is_enabled(); + static void set_enabled(bool enabled); + + _NoPThreadPoolGuard(): prev_mode_(_NoPThreadPoolGuard::is_enabled()) { + _NoPThreadPoolGuard::set_enabled(true); + } + ~_NoPThreadPoolGuard() { + _NoPThreadPoolGuard::set_enabled(prev_mode_); + } + private: + bool prev_mode_; +}; + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/any.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/any.h new file mode 100644 index 0000000000000000000000000000000000000000..1bb6306e75cda733f20e34123e0f6720aefb8a06 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/any.h @@ -0,0 +1,154 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_ANY_H__ +#define GOOGLE_PROTOBUF_ANY_H__ + +#include + +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +class FieldDescriptor; +class Message; + +namespace internal { + +extern const char kAnyFullTypeName[]; // "google.protobuf.Any". +extern const char kTypeGoogleApisComPrefix[]; // "type.googleapis.com/". +extern const char kTypeGoogleProdComPrefix[]; // "type.googleprod.com/". + +std::string GetTypeUrl(StringPiece message_name, + StringPiece type_url_prefix); + +// Helper class used to implement google::protobuf::Any. +class PROTOBUF_EXPORT AnyMetadata { + typedef ArenaStringPtr UrlType; + typedef ArenaStringPtr ValueType; + public: + // AnyMetadata does not take ownership of "type_url" and "value". + AnyMetadata(UrlType* type_url, ValueType* value); + + // Packs a message using the default type URL prefix: "type.googleapis.com". + // The resulted type URL will be "type.googleapis.com/". + template + void PackFrom(const T& message) { + InternalPackFrom(message, kTypeGoogleApisComPrefix, T::FullMessageName()); + } + + void PackFrom(const Message& message); + + // Packs a message using the given type URL prefix. The type URL will be + // constructed by concatenating the message type's full name to the prefix + // with an optional "/" separator if the prefix doesn't already end with "/". + // For example, both PackFrom(message, "type.googleapis.com") and + // PackFrom(message, "type.googleapis.com/") yield the same result type + // URL: "type.googleapis.com/". + template + void PackFrom(const T& message, StringPiece type_url_prefix) { + InternalPackFrom(message, type_url_prefix, T::FullMessageName()); + } + + void PackFrom(const Message& message, const std::string& type_url_prefix); + + // Unpacks the payload into the given message. Returns false if the message's + // type doesn't match the type specified in the type URL (i.e., the full + // name after the last "/" of the type URL doesn't match the message's actual + // full name) or parsing the payload has failed. + template + bool UnpackTo(T* message) const { + return InternalUnpackTo(T::FullMessageName(), message); + } + + bool UnpackTo(Message* message) const; + + // Checks whether the type specified in the type URL matches the given type. + // A type is considered matching if its full name matches the full name after + // the last "/" in the type URL. + template + bool Is() const { + return InternalIs(T::FullMessageName()); + } + + private: + void InternalPackFrom(const MessageLite& message, + StringPiece type_url_prefix, + StringPiece type_name); + bool InternalUnpackTo(StringPiece type_name, + MessageLite* message) const; + bool InternalIs(StringPiece type_name) const; + + UrlType* type_url_; + ValueType* value_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(AnyMetadata); +}; + +// Get the proto type name from Any::type_url value. For example, passing +// "type.googleapis.com/rpc.QueryOrigin" will return "rpc.QueryOrigin" in +// *full_type_name. Returns false if the type_url does not have a "/" +// in the type url separating the full type name. +// +// NOTE: this function is available publicly as: +// google::protobuf::Any() // static method on the generated message type. +bool ParseAnyTypeUrl(const std::string& type_url, std::string* full_type_name); + +// Get the proto type name and prefix from Any::type_url value. For example, +// passing "type.googleapis.com/rpc.QueryOrigin" will return +// "type.googleapis.com/" in *url_prefix and "rpc.QueryOrigin" in +// *full_type_name. Returns false if the type_url does not have a "/" in the +// type url separating the full type name. +bool ParseAnyTypeUrl(const std::string& type_url, std::string* url_prefix, + std::string* full_type_name); + +// See if message is of type google.protobuf.Any, if so, return the descriptors +// for "type_url" and "value" fields. +bool GetAnyFieldDescriptors(const Message& message, + const FieldDescriptor** type_url_field, + const FieldDescriptor** value_field); + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_ANY_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/any.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/any.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..7d54052e15455f00eb41246585ecd9e0470508e5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/any.pb.h @@ -0,0 +1,414 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/any.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fany_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fany_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fany_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2fany_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fany_2eproto; +PROTOBUF_NAMESPACE_OPEN +class Any; +class AnyDefaultTypeInternal; +PROTOBUF_EXPORT extern AnyDefaultTypeInternal _Any_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Any* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +// =================================================================== + +class PROTOBUF_EXPORT Any PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Any) */ { + public: + inline Any() : Any(nullptr) {} + virtual ~Any(); + + Any(const Any& from); + Any(Any&& from) noexcept + : Any() { + *this = ::std::move(from); + } + + inline Any& operator=(const Any& from) { + CopyFrom(from); + return *this; + } + inline Any& operator=(Any&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Any& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Any* internal_default_instance() { + return reinterpret_cast( + &_Any_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + // implements Any ----------------------------------------------- + + void PackFrom(const ::PROTOBUF_NAMESPACE_ID::Message& message) { + _any_metadata_.PackFrom(message); + } + void PackFrom(const ::PROTOBUF_NAMESPACE_ID::Message& message, + const std::string& type_url_prefix) { + _any_metadata_.PackFrom(message, type_url_prefix); + } + bool UnpackTo(::PROTOBUF_NAMESPACE_ID::Message* message) const { + return _any_metadata_.UnpackTo(message); + } + static bool GetAnyFieldDescriptors( + const ::PROTOBUF_NAMESPACE_ID::Message& message, + const ::PROTOBUF_NAMESPACE_ID::FieldDescriptor** type_url_field, + const ::PROTOBUF_NAMESPACE_ID::FieldDescriptor** value_field); + template ::value>::type> + void PackFrom(const T& message) { + _any_metadata_.PackFrom(message); + } + template ::value>::type> + void PackFrom(const T& message, + const std::string& type_url_prefix) { + _any_metadata_.PackFrom(message, type_url_prefix);} + template ::value>::type> + bool UnpackTo(T* message) const { + return _any_metadata_.UnpackTo(message); + } + template bool Is() const { + return _any_metadata_.Is(); + } + static bool ParseAnyTypeUrl(const string& type_url, + std::string* full_type_name); + friend void swap(Any& a, Any& b) { + a.Swap(&b); + } + inline void Swap(Any* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Any* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Any* New() const final { + return CreateMaybeMessage(nullptr); + } + + Any* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Any& from); + void MergeFrom(const Any& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Any* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Any"; + } + protected: + explicit Any(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fany_2eproto); + return ::descriptor_table_google_2fprotobuf_2fany_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kTypeUrlFieldNumber = 1, + kValueFieldNumber = 2, + }; + // string type_url = 1; + void clear_type_url(); + const std::string& type_url() const; + void set_type_url(const std::string& value); + void set_type_url(std::string&& value); + void set_type_url(const char* value); + void set_type_url(const char* value, size_t size); + std::string* mutable_type_url(); + std::string* release_type_url(); + void set_allocated_type_url(std::string* type_url); + private: + const std::string& _internal_type_url() const; + void _internal_set_type_url(const std::string& value); + std::string* _internal_mutable_type_url(); + public: + + // bytes value = 2; + void clear_value(); + const std::string& value() const; + void set_value(const std::string& value); + void set_value(std::string&& value); + void set_value(const char* value); + void set_value(const void* value, size_t size); + std::string* mutable_value(); + std::string* release_value(); + void set_allocated_value(std::string* value); + private: + const std::string& _internal_value() const; + void _internal_set_value(const std::string& value); + std::string* _internal_mutable_value(); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Any) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr type_url_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr value_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata _any_metadata_; + friend struct ::TableStruct_google_2fprotobuf_2fany_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Any + +// string type_url = 1; +inline void Any::clear_type_url() { + type_url_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Any::type_url() const { + // @@protoc_insertion_point(field_get:google.protobuf.Any.type_url) + return _internal_type_url(); +} +inline void Any::set_type_url(const std::string& value) { + _internal_set_type_url(value); + // @@protoc_insertion_point(field_set:google.protobuf.Any.type_url) +} +inline std::string* Any::mutable_type_url() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Any.type_url) + return _internal_mutable_type_url(); +} +inline const std::string& Any::_internal_type_url() const { + return type_url_.Get(); +} +inline void Any::_internal_set_type_url(const std::string& value) { + + type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Any::set_type_url(std::string&& value) { + + type_url_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Any.type_url) +} +inline void Any::set_type_url(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Any.type_url) +} +inline void Any::set_type_url(const char* value, + size_t size) { + + type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Any.type_url) +} +inline std::string* Any::_internal_mutable_type_url() { + + return type_url_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Any::release_type_url() { + // @@protoc_insertion_point(field_release:google.protobuf.Any.type_url) + return type_url_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Any::set_allocated_type_url(std::string* type_url) { + if (type_url != nullptr) { + + } else { + + } + type_url_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), type_url, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Any.type_url) +} + +// bytes value = 2; +inline void Any::clear_value() { + value_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Any::value() const { + // @@protoc_insertion_point(field_get:google.protobuf.Any.value) + return _internal_value(); +} +inline void Any::set_value(const std::string& value) { + _internal_set_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.Any.value) +} +inline std::string* Any::mutable_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Any.value) + return _internal_mutable_value(); +} +inline const std::string& Any::_internal_value() const { + return value_.Get(); +} +inline void Any::_internal_set_value(const std::string& value) { + + value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Any::set_value(std::string&& value) { + + value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Any.value) +} +inline void Any::set_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Any.value) +} +inline void Any::set_value(const void* value, + size_t size) { + + value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Any.value) +} +inline std::string* Any::_internal_mutable_value() { + + return value_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Any::release_value() { + // @@protoc_insertion_point(field_release:google.protobuf.Any.value) + return value_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Any::set_allocated_value(std::string* value) { + if (value != nullptr) { + + } else { + + } + value_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Any.value) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fany_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/api.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/api.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..5b5c902661b1330f34e1ad49c3e7d291d895bda5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/api.pb.h @@ -0,0 +1,1505 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/api.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fapi_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fapi_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fapi_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2fapi_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[3] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fapi_2eproto; +PROTOBUF_NAMESPACE_OPEN +class Api; +class ApiDefaultTypeInternal; +PROTOBUF_EXPORT extern ApiDefaultTypeInternal _Api_default_instance_; +class Method; +class MethodDefaultTypeInternal; +PROTOBUF_EXPORT extern MethodDefaultTypeInternal _Method_default_instance_; +class Mixin; +class MixinDefaultTypeInternal; +PROTOBUF_EXPORT extern MixinDefaultTypeInternal _Mixin_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Api* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Method* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Mixin* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +// =================================================================== + +class PROTOBUF_EXPORT Api PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Api) */ { + public: + inline Api() : Api(nullptr) {} + virtual ~Api(); + + Api(const Api& from); + Api(Api&& from) noexcept + : Api() { + *this = ::std::move(from); + } + + inline Api& operator=(const Api& from) { + CopyFrom(from); + return *this; + } + inline Api& operator=(Api&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Api& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Api* internal_default_instance() { + return reinterpret_cast( + &_Api_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Api& a, Api& b) { + a.Swap(&b); + } + inline void Swap(Api* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Api* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Api* New() const final { + return CreateMaybeMessage(nullptr); + } + + Api* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Api& from); + void MergeFrom(const Api& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Api* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Api"; + } + protected: + explicit Api(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fapi_2eproto); + return ::descriptor_table_google_2fprotobuf_2fapi_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kMethodsFieldNumber = 2, + kOptionsFieldNumber = 3, + kMixinsFieldNumber = 6, + kNameFieldNumber = 1, + kVersionFieldNumber = 4, + kSourceContextFieldNumber = 5, + kSyntaxFieldNumber = 7, + }; + // repeated .google.protobuf.Method methods = 2; + int methods_size() const; + private: + int _internal_methods_size() const; + public: + void clear_methods(); + PROTOBUF_NAMESPACE_ID::Method* mutable_methods(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Method >* + mutable_methods(); + private: + const PROTOBUF_NAMESPACE_ID::Method& _internal_methods(int index) const; + PROTOBUF_NAMESPACE_ID::Method* _internal_add_methods(); + public: + const PROTOBUF_NAMESPACE_ID::Method& methods(int index) const; + PROTOBUF_NAMESPACE_ID::Method* add_methods(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Method >& + methods() const; + + // repeated .google.protobuf.Option options = 3; + int options_size() const; + private: + int _internal_options_size() const; + public: + void clear_options(); + PROTOBUF_NAMESPACE_ID::Option* mutable_options(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* + mutable_options(); + private: + const PROTOBUF_NAMESPACE_ID::Option& _internal_options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* _internal_add_options(); + public: + const PROTOBUF_NAMESPACE_ID::Option& options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* add_options(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& + options() const; + + // repeated .google.protobuf.Mixin mixins = 6; + int mixins_size() const; + private: + int _internal_mixins_size() const; + public: + void clear_mixins(); + PROTOBUF_NAMESPACE_ID::Mixin* mutable_mixins(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Mixin >* + mutable_mixins(); + private: + const PROTOBUF_NAMESPACE_ID::Mixin& _internal_mixins(int index) const; + PROTOBUF_NAMESPACE_ID::Mixin* _internal_add_mixins(); + public: + const PROTOBUF_NAMESPACE_ID::Mixin& mixins(int index) const; + PROTOBUF_NAMESPACE_ID::Mixin* add_mixins(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Mixin >& + mixins() const; + + // string name = 1; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // string version = 4; + void clear_version(); + const std::string& version() const; + void set_version(const std::string& value); + void set_version(std::string&& value); + void set_version(const char* value); + void set_version(const char* value, size_t size); + std::string* mutable_version(); + std::string* release_version(); + void set_allocated_version(std::string* version); + private: + const std::string& _internal_version() const; + void _internal_set_version(const std::string& value); + std::string* _internal_mutable_version(); + public: + + // .google.protobuf.SourceContext source_context = 5; + bool has_source_context() const; + private: + bool _internal_has_source_context() const; + public: + void clear_source_context(); + const PROTOBUF_NAMESPACE_ID::SourceContext& source_context() const; + PROTOBUF_NAMESPACE_ID::SourceContext* release_source_context(); + PROTOBUF_NAMESPACE_ID::SourceContext* mutable_source_context(); + void set_allocated_source_context(PROTOBUF_NAMESPACE_ID::SourceContext* source_context); + private: + const PROTOBUF_NAMESPACE_ID::SourceContext& _internal_source_context() const; + PROTOBUF_NAMESPACE_ID::SourceContext* _internal_mutable_source_context(); + public: + void unsafe_arena_set_allocated_source_context( + PROTOBUF_NAMESPACE_ID::SourceContext* source_context); + PROTOBUF_NAMESPACE_ID::SourceContext* unsafe_arena_release_source_context(); + + // .google.protobuf.Syntax syntax = 7; + void clear_syntax(); + PROTOBUF_NAMESPACE_ID::Syntax syntax() const; + void set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + private: + PROTOBUF_NAMESPACE_ID::Syntax _internal_syntax() const; + void _internal_set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Api) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Method > methods_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option > options_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Mixin > mixins_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr version_; + PROTOBUF_NAMESPACE_ID::SourceContext* source_context_; + int syntax_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fapi_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT Method PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Method) */ { + public: + inline Method() : Method(nullptr) {} + virtual ~Method(); + + Method(const Method& from); + Method(Method&& from) noexcept + : Method() { + *this = ::std::move(from); + } + + inline Method& operator=(const Method& from) { + CopyFrom(from); + return *this; + } + inline Method& operator=(Method&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Method& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Method* internal_default_instance() { + return reinterpret_cast( + &_Method_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(Method& a, Method& b) { + a.Swap(&b); + } + inline void Swap(Method* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Method* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Method* New() const final { + return CreateMaybeMessage(nullptr); + } + + Method* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Method& from); + void MergeFrom(const Method& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Method* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Method"; + } + protected: + explicit Method(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fapi_2eproto); + return ::descriptor_table_google_2fprotobuf_2fapi_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kOptionsFieldNumber = 6, + kNameFieldNumber = 1, + kRequestTypeUrlFieldNumber = 2, + kResponseTypeUrlFieldNumber = 4, + kRequestStreamingFieldNumber = 3, + kResponseStreamingFieldNumber = 5, + kSyntaxFieldNumber = 7, + }; + // repeated .google.protobuf.Option options = 6; + int options_size() const; + private: + int _internal_options_size() const; + public: + void clear_options(); + PROTOBUF_NAMESPACE_ID::Option* mutable_options(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* + mutable_options(); + private: + const PROTOBUF_NAMESPACE_ID::Option& _internal_options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* _internal_add_options(); + public: + const PROTOBUF_NAMESPACE_ID::Option& options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* add_options(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& + options() const; + + // string name = 1; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // string request_type_url = 2; + void clear_request_type_url(); + const std::string& request_type_url() const; + void set_request_type_url(const std::string& value); + void set_request_type_url(std::string&& value); + void set_request_type_url(const char* value); + void set_request_type_url(const char* value, size_t size); + std::string* mutable_request_type_url(); + std::string* release_request_type_url(); + void set_allocated_request_type_url(std::string* request_type_url); + private: + const std::string& _internal_request_type_url() const; + void _internal_set_request_type_url(const std::string& value); + std::string* _internal_mutable_request_type_url(); + public: + + // string response_type_url = 4; + void clear_response_type_url(); + const std::string& response_type_url() const; + void set_response_type_url(const std::string& value); + void set_response_type_url(std::string&& value); + void set_response_type_url(const char* value); + void set_response_type_url(const char* value, size_t size); + std::string* mutable_response_type_url(); + std::string* release_response_type_url(); + void set_allocated_response_type_url(std::string* response_type_url); + private: + const std::string& _internal_response_type_url() const; + void _internal_set_response_type_url(const std::string& value); + std::string* _internal_mutable_response_type_url(); + public: + + // bool request_streaming = 3; + void clear_request_streaming(); + bool request_streaming() const; + void set_request_streaming(bool value); + private: + bool _internal_request_streaming() const; + void _internal_set_request_streaming(bool value); + public: + + // bool response_streaming = 5; + void clear_response_streaming(); + bool response_streaming() const; + void set_response_streaming(bool value); + private: + bool _internal_response_streaming() const; + void _internal_set_response_streaming(bool value); + public: + + // .google.protobuf.Syntax syntax = 7; + void clear_syntax(); + PROTOBUF_NAMESPACE_ID::Syntax syntax() const; + void set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + private: + PROTOBUF_NAMESPACE_ID::Syntax _internal_syntax() const; + void _internal_set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Method) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option > options_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr request_type_url_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr response_type_url_; + bool request_streaming_; + bool response_streaming_; + int syntax_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fapi_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT Mixin PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Mixin) */ { + public: + inline Mixin() : Mixin(nullptr) {} + virtual ~Mixin(); + + Mixin(const Mixin& from); + Mixin(Mixin&& from) noexcept + : Mixin() { + *this = ::std::move(from); + } + + inline Mixin& operator=(const Mixin& from) { + CopyFrom(from); + return *this; + } + inline Mixin& operator=(Mixin&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Mixin& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Mixin* internal_default_instance() { + return reinterpret_cast( + &_Mixin_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(Mixin& a, Mixin& b) { + a.Swap(&b); + } + inline void Swap(Mixin* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Mixin* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Mixin* New() const final { + return CreateMaybeMessage(nullptr); + } + + Mixin* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Mixin& from); + void MergeFrom(const Mixin& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Mixin* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Mixin"; + } + protected: + explicit Mixin(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fapi_2eproto); + return ::descriptor_table_google_2fprotobuf_2fapi_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kRootFieldNumber = 2, + }; + // string name = 1; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // string root = 2; + void clear_root(); + const std::string& root() const; + void set_root(const std::string& value); + void set_root(std::string&& value); + void set_root(const char* value); + void set_root(const char* value, size_t size); + std::string* mutable_root(); + std::string* release_root(); + void set_allocated_root(std::string* root); + private: + const std::string& _internal_root() const; + void _internal_set_root(const std::string& value); + std::string* _internal_mutable_root(); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Mixin) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr root_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fapi_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Api + +// string name = 1; +inline void Api::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Api::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.Api.name) + return _internal_name(); +} +inline void Api::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.Api.name) +} +inline std::string* Api::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Api.name) + return _internal_mutable_name(); +} +inline const std::string& Api::_internal_name() const { + return name_.Get(); +} +inline void Api::_internal_set_name(const std::string& value) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Api::set_name(std::string&& value) { + + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Api.name) +} +inline void Api::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Api.name) +} +inline void Api::set_name(const char* value, + size_t size) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Api.name) +} +inline std::string* Api::_internal_mutable_name() { + + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Api::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.Api.name) + return name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Api::set_allocated_name(std::string* name) { + if (name != nullptr) { + + } else { + + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Api.name) +} + +// repeated .google.protobuf.Method methods = 2; +inline int Api::_internal_methods_size() const { + return methods_.size(); +} +inline int Api::methods_size() const { + return _internal_methods_size(); +} +inline void Api::clear_methods() { + methods_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::Method* Api::mutable_methods(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.Api.methods) + return methods_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Method >* +Api::mutable_methods() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.Api.methods) + return &methods_; +} +inline const PROTOBUF_NAMESPACE_ID::Method& Api::_internal_methods(int index) const { + return methods_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::Method& Api::methods(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.Api.methods) + return _internal_methods(index); +} +inline PROTOBUF_NAMESPACE_ID::Method* Api::_internal_add_methods() { + return methods_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::Method* Api::add_methods() { + // @@protoc_insertion_point(field_add:google.protobuf.Api.methods) + return _internal_add_methods(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Method >& +Api::methods() const { + // @@protoc_insertion_point(field_list:google.protobuf.Api.methods) + return methods_; +} + +// repeated .google.protobuf.Option options = 3; +inline int Api::_internal_options_size() const { + return options_.size(); +} +inline int Api::options_size() const { + return _internal_options_size(); +} +inline PROTOBUF_NAMESPACE_ID::Option* Api::mutable_options(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.Api.options) + return options_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* +Api::mutable_options() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.Api.options) + return &options_; +} +inline const PROTOBUF_NAMESPACE_ID::Option& Api::_internal_options(int index) const { + return options_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::Option& Api::options(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.Api.options) + return _internal_options(index); +} +inline PROTOBUF_NAMESPACE_ID::Option* Api::_internal_add_options() { + return options_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::Option* Api::add_options() { + // @@protoc_insertion_point(field_add:google.protobuf.Api.options) + return _internal_add_options(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& +Api::options() const { + // @@protoc_insertion_point(field_list:google.protobuf.Api.options) + return options_; +} + +// string version = 4; +inline void Api::clear_version() { + version_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Api::version() const { + // @@protoc_insertion_point(field_get:google.protobuf.Api.version) + return _internal_version(); +} +inline void Api::set_version(const std::string& value) { + _internal_set_version(value); + // @@protoc_insertion_point(field_set:google.protobuf.Api.version) +} +inline std::string* Api::mutable_version() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Api.version) + return _internal_mutable_version(); +} +inline const std::string& Api::_internal_version() const { + return version_.Get(); +} +inline void Api::_internal_set_version(const std::string& value) { + + version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Api::set_version(std::string&& value) { + + version_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Api.version) +} +inline void Api::set_version(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Api.version) +} +inline void Api::set_version(const char* value, + size_t size) { + + version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Api.version) +} +inline std::string* Api::_internal_mutable_version() { + + return version_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Api::release_version() { + // @@protoc_insertion_point(field_release:google.protobuf.Api.version) + return version_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Api::set_allocated_version(std::string* version) { + if (version != nullptr) { + + } else { + + } + version_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), version, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Api.version) +} + +// .google.protobuf.SourceContext source_context = 5; +inline bool Api::_internal_has_source_context() const { + return this != internal_default_instance() && source_context_ != nullptr; +} +inline bool Api::has_source_context() const { + return _internal_has_source_context(); +} +inline const PROTOBUF_NAMESPACE_ID::SourceContext& Api::_internal_source_context() const { + const PROTOBUF_NAMESPACE_ID::SourceContext* p = source_context_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_SourceContext_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::SourceContext& Api::source_context() const { + // @@protoc_insertion_point(field_get:google.protobuf.Api.source_context) + return _internal_source_context(); +} +inline void Api::unsafe_arena_set_allocated_source_context( + PROTOBUF_NAMESPACE_ID::SourceContext* source_context) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(source_context_); + } + source_context_ = source_context; + if (source_context) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.Api.source_context) +} +inline PROTOBUF_NAMESPACE_ID::SourceContext* Api::release_source_context() { + + PROTOBUF_NAMESPACE_ID::SourceContext* temp = source_context_; + source_context_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::SourceContext* Api::unsafe_arena_release_source_context() { + // @@protoc_insertion_point(field_release:google.protobuf.Api.source_context) + + PROTOBUF_NAMESPACE_ID::SourceContext* temp = source_context_; + source_context_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::SourceContext* Api::_internal_mutable_source_context() { + + if (source_context_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + source_context_ = p; + } + return source_context_; +} +inline PROTOBUF_NAMESPACE_ID::SourceContext* Api::mutable_source_context() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Api.source_context) + return _internal_mutable_source_context(); +} +inline void Api::set_allocated_source_context(PROTOBUF_NAMESPACE_ID::SourceContext* source_context) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete reinterpret_cast< ::PROTOBUF_NAMESPACE_ID::MessageLite*>(source_context_); + } + if (source_context) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(source_context)->GetArena(); + if (message_arena != submessage_arena) { + source_context = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, source_context, submessage_arena); + } + + } else { + + } + source_context_ = source_context; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Api.source_context) +} + +// repeated .google.protobuf.Mixin mixins = 6; +inline int Api::_internal_mixins_size() const { + return mixins_.size(); +} +inline int Api::mixins_size() const { + return _internal_mixins_size(); +} +inline void Api::clear_mixins() { + mixins_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::Mixin* Api::mutable_mixins(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.Api.mixins) + return mixins_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Mixin >* +Api::mutable_mixins() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.Api.mixins) + return &mixins_; +} +inline const PROTOBUF_NAMESPACE_ID::Mixin& Api::_internal_mixins(int index) const { + return mixins_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::Mixin& Api::mixins(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.Api.mixins) + return _internal_mixins(index); +} +inline PROTOBUF_NAMESPACE_ID::Mixin* Api::_internal_add_mixins() { + return mixins_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::Mixin* Api::add_mixins() { + // @@protoc_insertion_point(field_add:google.protobuf.Api.mixins) + return _internal_add_mixins(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Mixin >& +Api::mixins() const { + // @@protoc_insertion_point(field_list:google.protobuf.Api.mixins) + return mixins_; +} + +// .google.protobuf.Syntax syntax = 7; +inline void Api::clear_syntax() { + syntax_ = 0; +} +inline PROTOBUF_NAMESPACE_ID::Syntax Api::_internal_syntax() const { + return static_cast< PROTOBUF_NAMESPACE_ID::Syntax >(syntax_); +} +inline PROTOBUF_NAMESPACE_ID::Syntax Api::syntax() const { + // @@protoc_insertion_point(field_get:google.protobuf.Api.syntax) + return _internal_syntax(); +} +inline void Api::_internal_set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value) { + + syntax_ = value; +} +inline void Api::set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value) { + _internal_set_syntax(value); + // @@protoc_insertion_point(field_set:google.protobuf.Api.syntax) +} + +// ------------------------------------------------------------------- + +// Method + +// string name = 1; +inline void Method::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Method::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.Method.name) + return _internal_name(); +} +inline void Method::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.Method.name) +} +inline std::string* Method::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Method.name) + return _internal_mutable_name(); +} +inline const std::string& Method::_internal_name() const { + return name_.Get(); +} +inline void Method::_internal_set_name(const std::string& value) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Method::set_name(std::string&& value) { + + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Method.name) +} +inline void Method::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Method.name) +} +inline void Method::set_name(const char* value, + size_t size) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Method.name) +} +inline std::string* Method::_internal_mutable_name() { + + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Method::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.Method.name) + return name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Method::set_allocated_name(std::string* name) { + if (name != nullptr) { + + } else { + + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Method.name) +} + +// string request_type_url = 2; +inline void Method::clear_request_type_url() { + request_type_url_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Method::request_type_url() const { + // @@protoc_insertion_point(field_get:google.protobuf.Method.request_type_url) + return _internal_request_type_url(); +} +inline void Method::set_request_type_url(const std::string& value) { + _internal_set_request_type_url(value); + // @@protoc_insertion_point(field_set:google.protobuf.Method.request_type_url) +} +inline std::string* Method::mutable_request_type_url() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Method.request_type_url) + return _internal_mutable_request_type_url(); +} +inline const std::string& Method::_internal_request_type_url() const { + return request_type_url_.Get(); +} +inline void Method::_internal_set_request_type_url(const std::string& value) { + + request_type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Method::set_request_type_url(std::string&& value) { + + request_type_url_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Method.request_type_url) +} +inline void Method::set_request_type_url(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + request_type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Method.request_type_url) +} +inline void Method::set_request_type_url(const char* value, + size_t size) { + + request_type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Method.request_type_url) +} +inline std::string* Method::_internal_mutable_request_type_url() { + + return request_type_url_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Method::release_request_type_url() { + // @@protoc_insertion_point(field_release:google.protobuf.Method.request_type_url) + return request_type_url_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Method::set_allocated_request_type_url(std::string* request_type_url) { + if (request_type_url != nullptr) { + + } else { + + } + request_type_url_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), request_type_url, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Method.request_type_url) +} + +// bool request_streaming = 3; +inline void Method::clear_request_streaming() { + request_streaming_ = false; +} +inline bool Method::_internal_request_streaming() const { + return request_streaming_; +} +inline bool Method::request_streaming() const { + // @@protoc_insertion_point(field_get:google.protobuf.Method.request_streaming) + return _internal_request_streaming(); +} +inline void Method::_internal_set_request_streaming(bool value) { + + request_streaming_ = value; +} +inline void Method::set_request_streaming(bool value) { + _internal_set_request_streaming(value); + // @@protoc_insertion_point(field_set:google.protobuf.Method.request_streaming) +} + +// string response_type_url = 4; +inline void Method::clear_response_type_url() { + response_type_url_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Method::response_type_url() const { + // @@protoc_insertion_point(field_get:google.protobuf.Method.response_type_url) + return _internal_response_type_url(); +} +inline void Method::set_response_type_url(const std::string& value) { + _internal_set_response_type_url(value); + // @@protoc_insertion_point(field_set:google.protobuf.Method.response_type_url) +} +inline std::string* Method::mutable_response_type_url() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Method.response_type_url) + return _internal_mutable_response_type_url(); +} +inline const std::string& Method::_internal_response_type_url() const { + return response_type_url_.Get(); +} +inline void Method::_internal_set_response_type_url(const std::string& value) { + + response_type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Method::set_response_type_url(std::string&& value) { + + response_type_url_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Method.response_type_url) +} +inline void Method::set_response_type_url(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + response_type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Method.response_type_url) +} +inline void Method::set_response_type_url(const char* value, + size_t size) { + + response_type_url_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Method.response_type_url) +} +inline std::string* Method::_internal_mutable_response_type_url() { + + return response_type_url_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Method::release_response_type_url() { + // @@protoc_insertion_point(field_release:google.protobuf.Method.response_type_url) + return response_type_url_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Method::set_allocated_response_type_url(std::string* response_type_url) { + if (response_type_url != nullptr) { + + } else { + + } + response_type_url_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), response_type_url, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Method.response_type_url) +} + +// bool response_streaming = 5; +inline void Method::clear_response_streaming() { + response_streaming_ = false; +} +inline bool Method::_internal_response_streaming() const { + return response_streaming_; +} +inline bool Method::response_streaming() const { + // @@protoc_insertion_point(field_get:google.protobuf.Method.response_streaming) + return _internal_response_streaming(); +} +inline void Method::_internal_set_response_streaming(bool value) { + + response_streaming_ = value; +} +inline void Method::set_response_streaming(bool value) { + _internal_set_response_streaming(value); + // @@protoc_insertion_point(field_set:google.protobuf.Method.response_streaming) +} + +// repeated .google.protobuf.Option options = 6; +inline int Method::_internal_options_size() const { + return options_.size(); +} +inline int Method::options_size() const { + return _internal_options_size(); +} +inline PROTOBUF_NAMESPACE_ID::Option* Method::mutable_options(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.Method.options) + return options_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* +Method::mutable_options() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.Method.options) + return &options_; +} +inline const PROTOBUF_NAMESPACE_ID::Option& Method::_internal_options(int index) const { + return options_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::Option& Method::options(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.Method.options) + return _internal_options(index); +} +inline PROTOBUF_NAMESPACE_ID::Option* Method::_internal_add_options() { + return options_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::Option* Method::add_options() { + // @@protoc_insertion_point(field_add:google.protobuf.Method.options) + return _internal_add_options(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& +Method::options() const { + // @@protoc_insertion_point(field_list:google.protobuf.Method.options) + return options_; +} + +// .google.protobuf.Syntax syntax = 7; +inline void Method::clear_syntax() { + syntax_ = 0; +} +inline PROTOBUF_NAMESPACE_ID::Syntax Method::_internal_syntax() const { + return static_cast< PROTOBUF_NAMESPACE_ID::Syntax >(syntax_); +} +inline PROTOBUF_NAMESPACE_ID::Syntax Method::syntax() const { + // @@protoc_insertion_point(field_get:google.protobuf.Method.syntax) + return _internal_syntax(); +} +inline void Method::_internal_set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value) { + + syntax_ = value; +} +inline void Method::set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value) { + _internal_set_syntax(value); + // @@protoc_insertion_point(field_set:google.protobuf.Method.syntax) +} + +// ------------------------------------------------------------------- + +// Mixin + +// string name = 1; +inline void Mixin::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Mixin::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.Mixin.name) + return _internal_name(); +} +inline void Mixin::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.Mixin.name) +} +inline std::string* Mixin::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Mixin.name) + return _internal_mutable_name(); +} +inline const std::string& Mixin::_internal_name() const { + return name_.Get(); +} +inline void Mixin::_internal_set_name(const std::string& value) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Mixin::set_name(std::string&& value) { + + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Mixin.name) +} +inline void Mixin::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Mixin.name) +} +inline void Mixin::set_name(const char* value, + size_t size) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Mixin.name) +} +inline std::string* Mixin::_internal_mutable_name() { + + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Mixin::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.Mixin.name) + return name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Mixin::set_allocated_name(std::string* name) { + if (name != nullptr) { + + } else { + + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Mixin.name) +} + +// string root = 2; +inline void Mixin::clear_root() { + root_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Mixin::root() const { + // @@protoc_insertion_point(field_get:google.protobuf.Mixin.root) + return _internal_root(); +} +inline void Mixin::set_root(const std::string& value) { + _internal_set_root(value); + // @@protoc_insertion_point(field_set:google.protobuf.Mixin.root) +} +inline std::string* Mixin::mutable_root() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Mixin.root) + return _internal_mutable_root(); +} +inline const std::string& Mixin::_internal_root() const { + return root_.Get(); +} +inline void Mixin::_internal_set_root(const std::string& value) { + + root_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Mixin::set_root(std::string&& value) { + + root_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Mixin.root) +} +inline void Mixin::set_root(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + root_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Mixin.root) +} +inline void Mixin::set_root(const char* value, + size_t size) { + + root_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Mixin.root) +} +inline std::string* Mixin::_internal_mutable_root() { + + return root_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Mixin::release_root() { + // @@protoc_insertion_point(field_release:google.protobuf.Mixin.root) + return root_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Mixin::set_allocated_root(std::string* root) { + if (root != nullptr) { + + } else { + + } + root_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), root, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Mixin.root) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fapi_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arena.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arena.h new file mode 100644 index 0000000000000000000000000000000000000000..33adc15cad401fbeb880476d3965a301232a5777 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arena.h @@ -0,0 +1,741 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file defines an Arena allocator for better allocation performance. + +#ifndef GOOGLE_PROTOBUF_ARENA_H__ +#define GOOGLE_PROTOBUF_ARENA_H__ + + +#include +#include +#include +#ifdef max +#undef max // Visual Studio defines this macro +#endif +#if defined(_MSC_VER) && !defined(_LIBCPP_STD_VER) && !_HAS_EXCEPTIONS +// Work around bugs in MSVC header when _HAS_EXCEPTIONS=0. +#include +#include +namespace std { +using type_info = ::type_info; +} +#else +#include +#endif + +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +struct ArenaOptions; // defined below + +} // namespace protobuf +} // namespace google + +namespace google { +namespace protobuf { + +class Arena; // defined below +class Message; // defined in message.h +class MessageLite; +template +class Map; + +namespace arena_metrics { + +void EnableArenaMetrics(ArenaOptions* options); + +} // namespace arena_metrics + +namespace internal { + +struct ArenaStringPtr; // defined in arenastring.h +class LazyField; // defined in lazy_field.h +class EpsCopyInputStream; // defined in parse_context.h + +template +class GenericTypeHandler; // defined in repeated_field.h + +// Templated cleanup methods. +template +void arena_destruct_object(void* object) { + reinterpret_cast(object)->~T(); +} +template +void arena_delete_object(void* object) { + delete reinterpret_cast(object); +} +inline void arena_free(void* object, size_t size) { +#if defined(__GXX_DELETE_WITH_SIZE__) || defined(__cpp_sized_deallocation) + ::operator delete(object, size); +#else + (void)size; + ::operator delete(object); +#endif +} + +} // namespace internal + +// ArenaOptions provides optional additional parameters to arena construction +// that control its block-allocation behavior. +struct ArenaOptions { + // This defines the size of the first block requested from the system malloc. + // Subsequent block sizes will increase in a geometric series up to a maximum. + size_t start_block_size; + + // This defines the maximum block size requested from system malloc (unless an + // individual arena allocation request occurs with a size larger than this + // maximum). Requested block sizes increase up to this value, then remain + // here. + size_t max_block_size; + + // An initial block of memory for the arena to use, or NULL for none. If + // provided, the block must live at least as long as the arena itself. The + // creator of the Arena retains ownership of the block after the Arena is + // destroyed. + char* initial_block; + + // The size of the initial block, if provided. + size_t initial_block_size; + + // A function pointer to an alloc method that returns memory blocks of size + // requested. By default, it contains a ptr to the malloc function. + // + // NOTE: block_alloc and dealloc functions are expected to behave like + // malloc and free, including Asan poisoning. + void* (*block_alloc)(size_t); + // A function pointer to a dealloc method that takes ownership of the blocks + // from the arena. By default, it contains a ptr to a wrapper function that + // calls free. + void (*block_dealloc)(void*, size_t); + + ArenaOptions() + : start_block_size(kDefaultStartBlockSize), + max_block_size(kDefaultMaxBlockSize), + initial_block(NULL), + initial_block_size(0), + block_alloc(&::operator new), + block_dealloc(&internal::arena_free), + on_arena_init(NULL), + on_arena_reset(NULL), + on_arena_destruction(NULL), + on_arena_allocation(NULL) {} + + private: + // Hooks for adding external functionality such as user-specific metrics + // collection, specific debugging abilities, etc. + // Init hook (if set) will always be called at Arena init time. Init hook may + // return a pointer to a cookie to be stored in the arena. Reset and + // destruction hooks will then be called with the same cookie pointer. This + // allows us to save an external object per arena instance and use it on the + // other hooks (Note: If init hook returns NULL, the other hooks will NOT be + // called on this arena instance). + // on_arena_reset and on_arena_destruction also receive the space used in the + // arena just before the reset. + void* (*on_arena_init)(Arena* arena); + void (*on_arena_reset)(Arena* arena, void* cookie, uint64 space_used); + void (*on_arena_destruction)(Arena* arena, void* cookie, uint64 space_used); + + // type_info is promised to be static - its lifetime extends to + // match program's lifetime (It is given by typeid operator). + // Note: typeid(void) will be passed as allocated_type every time we + // intentionally want to avoid monitoring an allocation. (i.e. internal + // allocations for managing the arena) + void (*on_arena_allocation)(const std::type_info* allocated_type, + uint64 alloc_size, void* cookie); + + // Constants define default starting block size and max block size for + // arena allocator behavior -- see descriptions above. + static const size_t kDefaultStartBlockSize = 256; + static const size_t kDefaultMaxBlockSize = 8192; + + friend void arena_metrics::EnableArenaMetrics(ArenaOptions*); + friend class Arena; + friend class ArenaOptionsTestFriend; +}; + +// Support for non-RTTI environments. (The metrics hooks API uses type +// information.) +#if PROTOBUF_RTTI +#define RTTI_TYPE_ID(type) (&typeid(type)) +#else +#define RTTI_TYPE_ID(type) (NULL) +#endif + +// Arena allocator. Arena allocation replaces ordinary (heap-based) allocation +// with new/delete, and improves performance by aggregating allocations into +// larger blocks and freeing allocations all at once. Protocol messages are +// allocated on an arena by using Arena::CreateMessage(Arena*), below, and +// are automatically freed when the arena is destroyed. +// +// This is a thread-safe implementation: multiple threads may allocate from the +// arena concurrently. Destruction is not thread-safe and the destructing +// thread must synchronize with users of the arena first. +// +// An arena provides two allocation interfaces: CreateMessage, which works +// for arena-enabled proto2 message types as well as other types that satisfy +// the appropriate protocol (described below), and Create, which works for +// any arbitrary type T. CreateMessage is better when the type T supports it, +// because this interface (i) passes the arena pointer to the created object so +// that its sub-objects and internal allocations can use the arena too, and (ii) +// elides the object's destructor call when possible. Create does not place +// any special requirements on the type T, and will invoke the object's +// destructor when the arena is destroyed. +// +// The arena message allocation protocol, required by +// CreateMessage(Arena* arena, Args&&... args), is as follows: +// +// - The type T must have (at least) two constructors: a constructor callable +// with `args` (without `arena`), called when a T is allocated on the heap; +// and a constructor callable with `Arena* arena, Args&&... args`, called when +// a T is allocated on an arena. If the second constructor is called with a +// NULL arena pointer, it must be equivalent to invoking the first +// (`args`-only) constructor. +// +// - The type T must have a particular type trait: a nested type +// |InternalArenaConstructable_|. This is usually a typedef to |void|. If no +// such type trait exists, then the instantiation CreateMessage will fail +// to compile. +// +// - The type T *may* have the type trait |DestructorSkippable_|. If this type +// trait is present in the type, then its destructor will not be called if and +// only if it was passed a non-NULL arena pointer. If this type trait is not +// present on the type, then its destructor is always called when the +// containing arena is destroyed. +// +// This protocol is implemented by all arena-enabled proto2 message classes as +// well as protobuf container types like RepeatedPtrField and Map. The protocol +// is internal to protobuf and is not guaranteed to be stable. Non-proto types +// should not rely on this protocol. +class PROTOBUF_EXPORT PROTOBUF_ALIGNAS(8) Arena final { + public: + // Arena constructor taking custom options. See ArenaOptions below for + // descriptions of the options available. + explicit Arena(const ArenaOptions& options) : impl_(options) { + Init(options); + } + + // Block overhead. Use this as a guide for how much to over-allocate the + // initial block if you want an allocation of size N to fit inside it. + // + // WARNING: if you allocate multiple objects, it is difficult to guarantee + // that a series of allocations will fit in the initial block, especially if + // Arena changes its alignment guarantees in the future! + static const size_t kBlockOverhead = internal::ArenaImpl::kBlockHeaderSize + + internal::ArenaImpl::kSerialArenaSize; + + // Default constructor with sensible default options, tuned for average + // use-cases. + Arena() : impl_(ArenaOptions()) { Init(ArenaOptions()); } + + ~Arena() { + if (hooks_cookie_) { + CallDestructorHooks(); + } + } + + void Init(const ArenaOptions& options) { + on_arena_allocation_ = options.on_arena_allocation; + on_arena_reset_ = options.on_arena_reset; + on_arena_destruction_ = options.on_arena_destruction; + // Call the initialization hook + if (options.on_arena_init != NULL) { + hooks_cookie_ = options.on_arena_init(this); + } else { + hooks_cookie_ = NULL; + } + } + + // API to create proto2 message objects on the arena. If the arena passed in + // is NULL, then a heap allocated object is returned. Type T must be a message + // defined in a .proto file with cc_enable_arenas set to true, otherwise a + // compilation error will occur. + // + // RepeatedField and RepeatedPtrField may also be instantiated directly on an + // arena with this method. + // + // This function also accepts any type T that satisfies the arena message + // allocation protocol, documented above. + template + PROTOBUF_ALWAYS_INLINE static T* CreateMessage(Arena* arena, Args&&... args) { + static_assert( + InternalHelper::is_arena_constructable::value, + "CreateMessage can only construct types that are ArenaConstructable"); + // We must delegate to CreateMaybeMessage() and NOT CreateMessageInternal() + // because protobuf generated classes specialize CreateMaybeMessage() and we + // need to use that specialization for code size reasons. + return Arena::CreateMaybeMessage(arena, std::forward(args)...); + } + + // API to create any objects on the arena. Note that only the object will + // be created on the arena; the underlying ptrs (in case of a proto2 message) + // will be still heap allocated. Proto messages should usually be allocated + // with CreateMessage() instead. + // + // Note that even if T satisfies the arena message construction protocol + // (InternalArenaConstructable_ trait and optional DestructorSkippable_ + // trait), as described above, this function does not follow the protocol; + // instead, it treats T as a black-box type, just as if it did not have these + // traits. Specifically, T's constructor arguments will always be only those + // passed to Create() -- no additional arena pointer is implicitly added. + // Furthermore, the destructor will always be called at arena destruction time + // (unless the destructor is trivial). Hence, from T's point of view, it is as + // if the object were allocated on the heap (except that the underlying memory + // is obtained from the arena). + template + PROTOBUF_ALWAYS_INLINE static T* Create(Arena* arena, Args&&... args) { + return CreateNoMessage(arena, is_arena_constructable(), + std::forward(args)...); + } + + // Create an array of object type T on the arena *without* invoking the + // constructor of T. If `arena` is null, then the return value should be freed + // with `delete[] x;` (or `::operator delete[](x);`). + // To ensure safe uses, this function checks at compile time + // (when compiled as C++11) that T is trivially default-constructible and + // trivially destructible. + template + PROTOBUF_ALWAYS_INLINE static T* CreateArray(Arena* arena, + size_t num_elements) { + static_assert(std::is_pod::value, + "CreateArray requires a trivially constructible type"); + static_assert(std::is_trivially_destructible::value, + "CreateArray requires a trivially destructible type"); + GOOGLE_CHECK_LE(num_elements, std::numeric_limits::max() / sizeof(T)) + << "Requested size is too large to fit into size_t."; + if (arena == NULL) { + return static_cast(::operator new[](num_elements * sizeof(T))); + } else { + return arena->CreateInternalRawArray(num_elements); + } + } + + // Returns the total space allocated by the arena, which is the sum of the + // sizes of the underlying blocks. This method is relatively fast; a counter + // is kept as blocks are allocated. + uint64 SpaceAllocated() const { return impl_.SpaceAllocated(); } + // Returns the total space used by the arena. Similar to SpaceAllocated but + // does not include free space and block overhead. The total space returned + // may not include space used by other threads executing concurrently with + // the call to this method. + uint64 SpaceUsed() const { return impl_.SpaceUsed(); } + + // Frees all storage allocated by this arena after calling destructors + // registered with OwnDestructor() and freeing objects registered with Own(). + // Any objects allocated on this arena are unusable after this call. It also + // returns the total space used by the arena which is the sums of the sizes + // of the allocated blocks. This method is not thread-safe. + PROTOBUF_NOINLINE uint64 Reset() { + // Call the reset hook + if (on_arena_reset_ != NULL) { + on_arena_reset_(this, hooks_cookie_, impl_.SpaceAllocated()); + } + return impl_.Reset(); + } + + // Adds |object| to a list of heap-allocated objects to be freed with |delete| + // when the arena is destroyed or reset. + template + PROTOBUF_NOINLINE void Own(T* object) { + OwnInternal(object, std::is_convertible()); + } + + // Adds |object| to a list of objects whose destructors will be manually + // called when the arena is destroyed or reset. This differs from Own() in + // that it does not free the underlying memory with |delete|; hence, it is + // normally only used for objects that are placement-newed into + // arena-allocated memory. + template + PROTOBUF_NOINLINE void OwnDestructor(T* object) { + if (object != NULL) { + impl_.AddCleanup(object, &internal::arena_destruct_object); + } + } + + // Adds a custom member function on an object to the list of destructors that + // will be manually called when the arena is destroyed or reset. This differs + // from OwnDestructor() in that any member function may be specified, not only + // the class destructor. + PROTOBUF_NOINLINE void OwnCustomDestructor(void* object, + void (*destruct)(void*)) { + impl_.AddCleanup(object, destruct); + } + + // Retrieves the arena associated with |value| if |value| is an arena-capable + // message, or NULL otherwise. If possible, the call resolves at compile time. + // Note that we can often devirtualize calls to `value->GetArena()` so usually + // calling this method is unnecessary. + template + PROTOBUF_ALWAYS_INLINE static Arena* GetArena(const T* value) { + return GetArenaInternal(value); + } + + template + class InternalHelper { + template + static char DestructorSkippable(const typename U::DestructorSkippable_*); + template + static double DestructorSkippable(...); + + typedef std::integral_constant< + bool, sizeof(DestructorSkippable(static_cast(0))) == + sizeof(char) || + std::is_trivially_destructible::value> + is_destructor_skippable; + + template + static char ArenaConstructable( + const typename U::InternalArenaConstructable_*); + template + static double ArenaConstructable(...); + + typedef std::integral_constant( + static_cast(0))) == + sizeof(char)> + is_arena_constructable; + + template () + .GetArena())>::value, + int>::type = 0> + static char HasGetArena(decltype(&U::GetArena)); + template + static double HasGetArena(...); + + typedef std::integral_constant(nullptr)) == + sizeof(char)> + has_get_arena; + + template + static T* Construct(void* ptr, Args&&... args) { + return new (ptr) T(std::forward(args)...); + } + + static Arena* GetArena(const T* p) { return p->GetArena(); } + + friend class Arena; + }; + + // Helper typetraits that indicates support for arenas in a type T at compile + // time. This is public only to allow construction of higher-level templated + // utilities. + // + // is_arena_constructable::value is true if the message type T has arena + // support enabled, and false otherwise. + // + // is_destructor_skippable::value is true if the message type T has told + // the arena that it is safe to skip the destructor, and false otherwise. + // + // This is inside Arena because only Arena has the friend relationships + // necessary to see the underlying generated code traits. + template + struct is_arena_constructable : InternalHelper::is_arena_constructable {}; + template + struct is_destructor_skippable : InternalHelper::is_destructor_skippable { + }; + + private: + template + struct has_get_arena : InternalHelper::has_get_arena {}; + + template + PROTOBUF_ALWAYS_INLINE static T* CreateMessageInternal(Arena* arena, + Args&&... args) { + static_assert( + InternalHelper::is_arena_constructable::value, + "CreateMessage can only construct types that are ArenaConstructable"); + if (arena == NULL) { + return new T(nullptr, std::forward(args)...); + } else { + return arena->DoCreateMessage(std::forward(args)...); + } + } + + // This specialization for no arguments is necessary, because its behavior is + // slightly different. When the arena pointer is nullptr, it calls T() + // instead of T(nullptr). + template + PROTOBUF_ALWAYS_INLINE static T* CreateMessageInternal(Arena* arena) { + static_assert( + InternalHelper::is_arena_constructable::value, + "CreateMessage can only construct types that are ArenaConstructable"); + if (arena == NULL) { + return new T(); + } else { + return arena->DoCreateMessage(); + } + } + + template + PROTOBUF_ALWAYS_INLINE static T* CreateInternal(Arena* arena, + Args&&... args) { + if (arena == NULL) { + return new T(std::forward(args)...); + } else { + return arena->DoCreate(std::is_trivially_destructible::value, + std::forward(args)...); + } + } + + void CallDestructorHooks(); + void OnArenaAllocation(const std::type_info* allocated_type, size_t n) const; + inline void AllocHook(const std::type_info* allocated_type, size_t n) const { + if (PROTOBUF_PREDICT_FALSE(hooks_cookie_ != NULL)) { + OnArenaAllocation(allocated_type, n); + } + } + + // Allocate and also optionally call on_arena_allocation callback with the + // allocated type info when the hooks are in place in ArenaOptions and + // the cookie is not null. + template + PROTOBUF_ALWAYS_INLINE void* AllocateInternal(bool skip_explicit_ownership) { + static_assert(alignof(T) <= 8, "T is overaligned, see b/151247138"); + const size_t n = internal::AlignUpTo8(sizeof(T)); + AllocHook(RTTI_TYPE_ID(T), n); + // Monitor allocation if needed. + if (skip_explicit_ownership) { + return AllocateAlignedNoHook(n); + } else { + return impl_.AllocateAlignedAndAddCleanup( + n, &internal::arena_destruct_object); + } + } + + // CreateMessage requires that T supports arenas, but this private method + // works whether or not T supports arenas. These are not exposed to user code + // as it can cause confusing API usages, and end up having double free in + // user code. These are used only internally from LazyField and Repeated + // fields, since they are designed to work in all mode combinations. + template + PROTOBUF_ALWAYS_INLINE static Msg* DoCreateMaybeMessage(Arena* arena, + std::true_type, + Args&&... args) { + return CreateMessageInternal(arena, std::forward(args)...); + } + + template + PROTOBUF_ALWAYS_INLINE static T* DoCreateMaybeMessage(Arena* arena, + std::false_type, + Args&&... args) { + return CreateInternal(arena, std::forward(args)...); + } + + template + PROTOBUF_ALWAYS_INLINE static T* CreateMaybeMessage(Arena* arena, + Args&&... args) { + return DoCreateMaybeMessage(arena, is_arena_constructable(), + std::forward(args)...); + } + + template + PROTOBUF_ALWAYS_INLINE static T* CreateNoMessage(Arena* arena, std::true_type, + Args&&... args) { + // User is constructing with Create() despite the fact that T supports arena + // construction. In this case we have to delegate to CreateInternal(), and + // we can't use any CreateMaybeMessage() specialization that may be defined. + return CreateInternal(arena, std::forward(args)...); + } + + template + PROTOBUF_ALWAYS_INLINE static T* CreateNoMessage(Arena* arena, + std::false_type, + Args&&... args) { + // User is constructing with Create() and the type does not support arena + // construction. In this case we can delegate to CreateMaybeMessage() and + // use any specialization that may be available for that. + return CreateMaybeMessage(arena, std::forward(args)...); + } + + // Just allocate the required size for the given type assuming the + // type has a trivial constructor. + template + PROTOBUF_ALWAYS_INLINE T* CreateInternalRawArray(size_t num_elements) { + GOOGLE_CHECK_LE(num_elements, std::numeric_limits::max() / sizeof(T)) + << "Requested size is too large to fit into size_t."; + const size_t n = internal::AlignUpTo8(sizeof(T) * num_elements); + // Monitor allocation if needed. + AllocHook(RTTI_TYPE_ID(T), n); + return static_cast(AllocateAlignedNoHook(n)); + } + + template + PROTOBUF_ALWAYS_INLINE T* DoCreate(bool skip_explicit_ownership, + Args&&... args) { + return new (AllocateInternal(skip_explicit_ownership)) + T(std::forward(args)...); + } + template + PROTOBUF_ALWAYS_INLINE T* DoCreateMessage(Args&&... args) { + return InternalHelper::Construct( + AllocateInternal(InternalHelper::is_destructor_skippable::value), + this, std::forward(args)...); + } + + // CreateInArenaStorage is used to implement map field. Without it, + // Map need to call generated message's protected arena constructor, + // which needs to declare Map as friend of generated message. + template + static void CreateInArenaStorage(T* ptr, Arena* arena, Args&&... args) { + CreateInArenaStorageInternal(ptr, arena, + typename is_arena_constructable::type(), + std::forward(args)...); + RegisterDestructorInternal( + ptr, arena, + typename InternalHelper::is_destructor_skippable::type()); + } + + template + static void CreateInArenaStorageInternal(T* ptr, Arena* arena, + std::true_type, Args&&... args) { + InternalHelper::Construct(ptr, arena, std::forward(args)...); + } + template + static void CreateInArenaStorageInternal(T* ptr, Arena* /* arena */, + std::false_type, Args&&... args) { + new (ptr) T(std::forward(args)...); + } + + template + static void RegisterDestructorInternal(T* /* ptr */, Arena* /* arena */, + std::true_type) {} + template + static void RegisterDestructorInternal(T* ptr, Arena* arena, + std::false_type) { + arena->OwnDestructor(ptr); + } + + // These implement Own(), which registers an object for deletion (destructor + // call and operator delete()). The second parameter has type 'true_type' if T + // is a subtype of Message and 'false_type' otherwise. Collapsing + // all template instantiations to one for generic Message reduces code size, + // using the virtual destructor instead. + template + PROTOBUF_ALWAYS_INLINE void OwnInternal(T* object, std::true_type) { + if (object != NULL) { + impl_.AddCleanup(object, &internal::arena_delete_object); + } + } + template + PROTOBUF_ALWAYS_INLINE void OwnInternal(T* object, std::false_type) { + if (object != NULL) { + impl_.AddCleanup(object, &internal::arena_delete_object); + } + } + + // Implementation for GetArena(). Only message objects with + // InternalArenaConstructable_ tags can be associated with an arena, and such + // objects must implement a GetArena() method. + template ::value, int>::type = 0> + PROTOBUF_ALWAYS_INLINE static Arena* GetArenaInternal(const T* value) { + return InternalHelper::GetArena(value); + } + template ::value && + has_get_arena::value, + int>::type = 0> + PROTOBUF_ALWAYS_INLINE static Arena* GetArenaInternal(const T* value) { + return value->GetArena(); + } + template ::value && + !has_get_arena::value, + int>::type = 0> + PROTOBUF_ALWAYS_INLINE static Arena* GetArenaInternal(const T* value) { + (void)value; + return nullptr; + } + + // For friends of arena. + void* AllocateAligned(size_t n) { + AllocHook(NULL, n); + return AllocateAlignedNoHook(internal::AlignUpTo8(n)); + } + template + void* AllocateAlignedTo(size_t n) { + static_assert(Align > 0, "Alignment must be greater than 0"); + static_assert((Align & (Align - 1)) == 0, "Alignment must be power of two"); + if (Align <= 8) return AllocateAligned(n); + // TODO(b/151247138): if the pointer would have been aligned already, + // this is wasting space. We should pass the alignment down. + uintptr_t ptr = reinterpret_cast(AllocateAligned(n + Align - 8)); + ptr = (ptr + Align - 1) & -Align; + return reinterpret_cast(ptr); + } + + void* AllocateAlignedNoHook(size_t n); + + internal::ArenaImpl impl_; + + void (*on_arena_allocation_)(const std::type_info* allocated_type, + uint64 alloc_size, void* cookie); + void (*on_arena_reset_)(Arena* arena, void* cookie, uint64 space_used); + void (*on_arena_destruction_)(Arena* arena, void* cookie, uint64 space_used); + + // The arena may save a cookie it receives from the external on_init hook + // and then use it when calling the on_reset and on_destruction hooks. + void* hooks_cookie_; + + template + friend class internal::GenericTypeHandler; + friend struct internal::ArenaStringPtr; // For AllocateAligned. + friend class internal::LazyField; // For CreateMaybeMessage. + friend class internal::EpsCopyInputStream; // For parser performance + friend class MessageLite; + template + friend class Map; +}; + +// Defined above for supporting environments without RTTI. +#undef RTTI_TYPE_ID + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_ARENA_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arena_impl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arena_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..701ae5fae6d66aff211e7f8972dd9763735f619e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arena_impl.h @@ -0,0 +1,394 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file defines an Arena allocator for better allocation performance. + +#ifndef GOOGLE_PROTOBUF_ARENA_IMPL_H__ +#define GOOGLE_PROTOBUF_ARENA_IMPL_H__ + +#include +#include + +#include +#include + +#ifdef ADDRESS_SANITIZER +#include +#endif // ADDRESS_SANITIZER + +#include + + +namespace google { +namespace protobuf { +namespace internal { + +inline size_t AlignUpTo8(size_t n) { + // Align n to next multiple of 8 (from Hacker's Delight, Chapter 3.) + return (n + 7) & static_cast(-8); +} + +using LifecycleId = int64_t; + +// This class provides the core Arena memory allocation library. Different +// implementations only need to implement the public interface below. +// Arena is not a template type as that would only be useful if all protos +// in turn would be templates, which will/cannot happen. However separating +// the memory allocation part from the cruft of the API users expect we can +// use #ifdef the select the best implementation based on hardware / OS. +class PROTOBUF_EXPORT ArenaImpl { + public: + struct Options { + size_t start_block_size; + size_t max_block_size; + char* initial_block; + size_t initial_block_size; + void* (*block_alloc)(size_t); + void (*block_dealloc)(void*, size_t); + + template + explicit Options(const O& options) + : start_block_size(options.start_block_size), + max_block_size(options.max_block_size), + initial_block(options.initial_block), + initial_block_size(options.initial_block_size), + block_alloc(options.block_alloc), + block_dealloc(options.block_dealloc) {} + }; + + template + explicit ArenaImpl(const O& options) : options_(options) { + if (options_.initial_block != NULL && options_.initial_block_size > 0) { + GOOGLE_CHECK_GE(options_.initial_block_size, sizeof(Block)) + << ": Initial block size too small for header."; + initial_block_ = reinterpret_cast(options_.initial_block); + } else { + initial_block_ = NULL; + } + + Init(); + } + + // Destructor deletes all owned heap allocated objects, and destructs objects + // that have non-trivial destructors, except for proto2 message objects whose + // destructors can be skipped. Also, frees all blocks except the initial block + // if it was passed in. + ~ArenaImpl(); + + uint64 Reset(); + + uint64 SpaceAllocated() const; + uint64 SpaceUsed() const; + + void* AllocateAligned(size_t n) { + SerialArena* arena; + if (PROTOBUF_PREDICT_TRUE(GetSerialArenaFast(&arena))) { + return arena->AllocateAligned(n); + } else { + return AllocateAlignedFallback(n); + } + } + + // This function allocates n bytes if the common happy case is true and + // returns true. Otherwise does nothing and returns false. This strange + // semantics is necessary to allow callers to program functions that only + // have fallback function calls in tail position. This substantially improves + // code for the happy path. + PROTOBUF_ALWAYS_INLINE bool MaybeAllocateAligned(size_t n, void** out) { + SerialArena* a; + if (PROTOBUF_PREDICT_TRUE(GetSerialArenaFromThreadCache(&a))) { + return a->MaybeAllocateAligned(n, out); + } + return false; + } + + void* AllocateAlignedAndAddCleanup(size_t n, void (*cleanup)(void*)); + + // Add object pointer and cleanup function pointer to the list. + void AddCleanup(void* elem, void (*cleanup)(void*)); + + private: + friend class ArenaBenchmark; + + void* AllocateAlignedFallback(size_t n); + void* AllocateAlignedAndAddCleanupFallback(size_t n, void (*cleanup)(void*)); + void AddCleanupFallback(void* elem, void (*cleanup)(void*)); + + // Node contains the ptr of the object to be cleaned up and the associated + // cleanup function ptr. + struct CleanupNode { + void* elem; // Pointer to the object to be cleaned up. + void (*cleanup)(void*); // Function pointer to the destructor or deleter. + }; + + // Cleanup uses a chunked linked list, to reduce pointer chasing. + struct CleanupChunk { + static size_t SizeOf(size_t i) { + return sizeof(CleanupChunk) + (sizeof(CleanupNode) * (i - 1)); + } + size_t size; // Total elements in the list. + CleanupChunk* next; // Next node in the list. + CleanupNode nodes[1]; // True length is |size|. + }; + + class Block; + + // A thread-unsafe Arena that can only be used within its owning thread. + class PROTOBUF_EXPORT SerialArena { + public: + // The allocate/free methods here are a little strange, since SerialArena is + // allocated inside a Block which it also manages. This is to avoid doing + // an extra allocation for the SerialArena itself. + + // Creates a new SerialArena inside Block* and returns it. + static SerialArena* New(Block* b, void* owner, ArenaImpl* arena); + + // Destroys this SerialArena, freeing all blocks with the given dealloc + // function, except any block equal to |initial_block|. + static uint64 Free(SerialArena* serial, Block* initial_block, + void (*block_dealloc)(void*, size_t)); + + void CleanupList(); + uint64 SpaceUsed() const; + + bool HasSpace(size_t n) { return n <= static_cast(limit_ - ptr_); } + + void* AllocateAligned(size_t n) { + GOOGLE_DCHECK_EQ(internal::AlignUpTo8(n), n); // Must be already aligned. + GOOGLE_DCHECK_GE(limit_, ptr_); + if (PROTOBUF_PREDICT_FALSE(!HasSpace(n))) { + return AllocateAlignedFallback(n); + } + void* ret = ptr_; + ptr_ += n; +#ifdef ADDRESS_SANITIZER + ASAN_UNPOISON_MEMORY_REGION(ret, n); +#endif // ADDRESS_SANITIZER + return ret; + } + + // Allocate space if the current region provides enough space. + bool MaybeAllocateAligned(size_t n, void** out) { + GOOGLE_DCHECK_EQ(internal::AlignUpTo8(n), n); // Must be already aligned. + GOOGLE_DCHECK_GE(limit_, ptr_); + if (PROTOBUF_PREDICT_FALSE(!HasSpace(n))) return false; + void* ret = ptr_; + ptr_ += n; +#ifdef ADDRESS_SANITIZER + ASAN_UNPOISON_MEMORY_REGION(ret, n); +#endif // ADDRESS_SANITIZER + *out = ret; + return true; + } + + void AddCleanup(void* elem, void (*cleanup)(void*)) { + if (PROTOBUF_PREDICT_FALSE(cleanup_ptr_ == cleanup_limit_)) { + AddCleanupFallback(elem, cleanup); + return; + } + cleanup_ptr_->elem = elem; + cleanup_ptr_->cleanup = cleanup; + cleanup_ptr_++; + } + + void* AllocateAlignedAndAddCleanup(size_t n, void (*cleanup)(void*)) { + void* ret = AllocateAligned(n); + AddCleanup(ret, cleanup); + return ret; + } + + void* owner() const { return owner_; } + SerialArena* next() const { return next_; } + void set_next(SerialArena* next) { next_ = next; } + + private: + void* AllocateAlignedFallback(size_t n); + void AddCleanupFallback(void* elem, void (*cleanup)(void*)); + void CleanupListFallback(); + + ArenaImpl* arena_; // Containing arena. + void* owner_; // &ThreadCache of this thread; + Block* head_; // Head of linked list of blocks. + CleanupChunk* cleanup_; // Head of cleanup list. + SerialArena* next_; // Next SerialArena in this linked list. + + // Next pointer to allocate from. Always 8-byte aligned. Points inside + // head_ (and head_->pos will always be non-canonical). We keep these + // here to reduce indirection. + char* ptr_; + char* limit_; + + // Next CleanupList members to append to. These point inside cleanup_. + CleanupNode* cleanup_ptr_; + CleanupNode* cleanup_limit_; + }; + + // Blocks are variable length malloc-ed objects. The following structure + // describes the common header for all blocks. + class PROTOBUF_EXPORT Block { + public: + Block(size_t size, Block* next); + + char* Pointer(size_t n) { + GOOGLE_DCHECK(n <= size_); + return reinterpret_cast(this) + n; + } + + Block* next() const { return next_; } + size_t pos() const { return pos_; } + size_t size() const { return size_; } + void set_pos(size_t pos) { pos_ = pos; } + + private: + Block* next_; // Next block for this thread. + size_t pos_; + size_t size_; + // data follows + }; + + struct ThreadCache { +#if defined(GOOGLE_PROTOBUF_NO_THREADLOCAL) + // If we are using the ThreadLocalStorage class to store the ThreadCache, + // then the ThreadCache's default constructor has to be responsible for + // initializing it. + ThreadCache() : last_lifecycle_id_seen(-1), last_serial_arena(NULL) {} +#endif + + // The ThreadCache is considered valid as long as this matches the + // lifecycle_id of the arena being used. + LifecycleId last_lifecycle_id_seen; + SerialArena* last_serial_arena; + }; + static std::atomic lifecycle_id_generator_; +#if defined(GOOGLE_PROTOBUF_NO_THREADLOCAL) + // Android ndk does not support __thread keyword so we use a custom thread + // local storage class we implemented. + // iOS also does not support the __thread keyword. + static ThreadCache& thread_cache(); +#elif defined(PROTOBUF_USE_DLLS) + // Thread local variables cannot be exposed through DLL interface but we can + // wrap them in static functions. + static ThreadCache& thread_cache(); +#else + static PROTOBUF_THREAD_LOCAL ThreadCache thread_cache_; + static ThreadCache& thread_cache() { return thread_cache_; } +#endif + + void Init(); + + // Free all blocks and return the total space used which is the sums of sizes + // of the all the allocated blocks. + uint64 FreeBlocks(); + // Delete or Destruct all objects owned by the arena. + void CleanupList(); + + inline void CacheSerialArena(SerialArena* serial) { + thread_cache().last_serial_arena = serial; + thread_cache().last_lifecycle_id_seen = lifecycle_id_; + // TODO(haberman): evaluate whether we would gain efficiency by getting rid + // of hint_. It's the only write we do to ArenaImpl in the allocation path, + // which will dirty the cache line. + + hint_.store(serial, std::memory_order_release); + } + + std::atomic + threads_; // Pointer to a linked list of SerialArena. + std::atomic hint_; // Fast thread-local block access + std::atomic space_allocated_; // Total size of all allocated blocks. + + Block* initial_block_; // If non-NULL, points to the block that came from + // user data. + + Block* NewBlock(Block* last_block, size_t min_bytes); + + SerialArena* GetSerialArena(); + PROTOBUF_ALWAYS_INLINE bool GetSerialArenaFast(SerialArena** arena) { + if (GetSerialArenaFromThreadCache(arena)) return true; + + // Check whether we own the last accessed SerialArena on this arena. This + // fast path optimizes the case where a single thread uses multiple arenas. + ThreadCache* tc = &thread_cache(); + SerialArena* serial = hint_.load(std::memory_order_acquire); + if (PROTOBUF_PREDICT_TRUE(serial != NULL && serial->owner() == tc)) { + *arena = serial; + return true; + } + return false; + } + + PROTOBUF_ALWAYS_INLINE bool GetSerialArenaFromThreadCache( + SerialArena** arena) { + // If this thread already owns a block in this arena then try to use that. + // This fast path optimizes the case where multiple threads allocate from + // the same arena. + ThreadCache* tc = &thread_cache(); + if (PROTOBUF_PREDICT_TRUE(tc->last_lifecycle_id_seen == lifecycle_id_)) { + *arena = tc->last_serial_arena; + return true; + } + return false; + } + SerialArena* GetSerialArenaFallback(void* me); + LifecycleId lifecycle_id_; // Unique for each arena. Changes on Reset(). + + Options options_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ArenaImpl); + // All protos have pointers back to the arena hence Arena must have + // pointer stability. + ArenaImpl(ArenaImpl&&) = delete; + ArenaImpl& operator=(ArenaImpl&&) = delete; + + public: + // kBlockHeaderSize is sizeof(Block), aligned up to the nearest multiple of 8 + // to protect the invariant that pos is always at a multiple of 8. + static const size_t kBlockHeaderSize = + (sizeof(Block) + 7) & static_cast(-8); + static const size_t kSerialArenaSize = + (sizeof(SerialArena) + 7) & static_cast(-8); + static_assert(kBlockHeaderSize % 8 == 0, + "kBlockHeaderSize must be a multiple of 8."); + static_assert(kSerialArenaSize % 8 == 0, + "kSerialArenaSize must be a multiple of 8."); +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_ARENA_IMPL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arenastring.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arenastring.h new file mode 100644 index 0000000000000000000000000000000000000000..43955d71ee43de9fa9724b7f263394e29ecad73b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/arenastring.h @@ -0,0 +1,410 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_ARENASTRING_H__ +#define GOOGLE_PROTOBUF_ARENASTRING_H__ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + + +// This is the implementation of arena string fields written for the open-source +// release. The ArenaStringPtr struct below is an internal implementation class +// and *should not be used* by user code. It is used to collect string +// operations together into one place and abstract away the underlying +// string-field pointer representation, so that (for example) an alternate +// implementation that knew more about ::std::string's internals could integrate +// more closely with the arena allocator. + +namespace google { +namespace protobuf { +namespace internal { + +template +class TaggedPtr { + public: + void Set(T* p) { ptr_ = reinterpret_cast(p); } + T* Get() const { return reinterpret_cast(ptr_); } + + bool IsNull() { return ptr_ == 0; } + + private: + uintptr_t ptr_; +}; + +struct PROTOBUF_EXPORT ArenaStringPtr { + inline void Set(const ::std::string* default_value, + const ::std::string& value, Arena* arena) { + if (ptr_ == default_value) { + CreateInstance(arena, &value); + } else { + *ptr_ = value; + } + } + + inline void SetLite(const ::std::string* default_value, + const ::std::string& value, Arena* arena) { + Set(default_value, value, arena); + } + + // Basic accessors. + inline const ::std::string& Get() const { return *ptr_; } + + inline ::std::string* Mutable(const ::std::string* default_value, + Arena* arena) { + if (ptr_ == default_value) { + CreateInstance(arena, default_value); + } + return ptr_; + } + + // Release returns a ::std::string* instance that is heap-allocated and is not + // Own()'d by any arena. If the field was not set, it returns NULL. The caller + // retains ownership. Clears this field back to NULL state. Used to implement + // release_() methods on generated classes. + inline ::std::string* Release(const ::std::string* default_value, + Arena* arena) { + if (ptr_ == default_value) { + return NULL; + } + return ReleaseNonDefault(default_value, arena); + } + + // Similar to Release, but ptr_ cannot be the default_value. + inline ::std::string* ReleaseNonDefault(const ::std::string* default_value, + Arena* arena) { + GOOGLE_DCHECK(!IsDefault(default_value)); + ::std::string* released = NULL; + if (arena != NULL) { + // ptr_ is owned by the arena. + released = new ::std::string; + released->swap(*ptr_); + } else { + released = ptr_; + } + ptr_ = const_cast< ::std::string*>(default_value); + return released; + } + + // UnsafeArenaRelease returns a ::std::string*, but it may be arena-owned + // (i.e. have its destructor already registered) if arena != NULL. If the + // field was not set, this returns NULL. This method clears this field back to + // NULL state. Used to implement unsafe_arena_release_() methods on + // generated classes. + inline ::std::string* UnsafeArenaRelease(const ::std::string* default_value, + Arena* /* arena */) { + if (ptr_ == default_value) { + return NULL; + } + ::std::string* released = ptr_; + ptr_ = const_cast< ::std::string*>(default_value); + return released; + } + + // Takes a string that is heap-allocated, and takes ownership. The string's + // destructor is registered with the arena. Used to implement + // set_allocated_ in generated classes. + inline void SetAllocated(const ::std::string* default_value, + ::std::string* value, Arena* arena) { + if (arena == NULL && ptr_ != default_value) { + Destroy(default_value, arena); + } + if (value != NULL) { + ptr_ = value; + if (arena != NULL) { + arena->Own(value); + } + } else { + ptr_ = const_cast< ::std::string*>(default_value); + } + } + + // Takes a string that has lifetime equal to the arena's lifetime. The arena + // must be non-null. It is safe only to pass this method a value returned by + // UnsafeArenaRelease() on another field of a message in the same arena. Used + // to implement unsafe_arena_set_allocated_ in generated classes. + inline void UnsafeArenaSetAllocated(const ::std::string* default_value, + ::std::string* value, + Arena* /* arena */) { + if (value != NULL) { + ptr_ = value; + } else { + ptr_ = const_cast< ::std::string*>(default_value); + } + } + + // Swaps internal pointers. Arena-safety semantics: this is guarded by the + // logic in Swap()/UnsafeArenaSwap() at the message level, so this method is + // 'unsafe' if called directly. + PROTOBUF_ALWAYS_INLINE void Swap(ArenaStringPtr* other) { + std::swap(ptr_, other->ptr_); + } + PROTOBUF_ALWAYS_INLINE void Swap(ArenaStringPtr* other, + const ::std::string* default_value, + Arena* arena) { +#ifndef NDEBUG + // For debug builds, we swap the contents of the string, rather than the + // string instances themselves. This invalidates previously taken const + // references that are (per our documentation) invalidated by calling Swap() + // on the message. + // + // If both strings are the default_value, swapping is uninteresting. + // Otherwise, we use ArenaStringPtr::Mutable() to access the string, to + // ensure that we do not try to mutate default_value itself. + if (IsDefault(default_value) && other->IsDefault(default_value)) { + return; + } + + ::std::string* this_ptr = Mutable(default_value, arena); + ::std::string* other_ptr = other->Mutable(default_value, arena); + + this_ptr->swap(*other_ptr); +#else + std::swap(ptr_, other->ptr_); + (void)default_value; + (void)arena; +#endif + } + + // Frees storage (if not on an arena). + inline void Destroy(const ::std::string* default_value, Arena* arena) { + if (arena == NULL && ptr_ != default_value) { + delete ptr_; + } + } + + // Clears content, but keeps allocated string if arena != NULL, to avoid the + // overhead of heap operations. After this returns, the content (as seen by + // the user) will always be the empty string. Assumes that |default_value| + // is an empty string. + inline void ClearToEmpty(const ::std::string* default_value, + Arena* /* arena */) { + if (ptr_ == default_value) { + // Already set to default (which is empty) -- do nothing. + } else { + ptr_->clear(); + } + } + + // Clears content, assuming that the current value is not the empty string + // default. + inline void ClearNonDefaultToEmpty() { ptr_->clear(); } + inline void ClearNonDefaultToEmptyNoArena() { ptr_->clear(); } + + // Clears content, but keeps allocated string if arena != NULL, to avoid the + // overhead of heap operations. After this returns, the content (as seen by + // the user) will always be equal to |default_value|. + inline void ClearToDefault(const ::std::string* default_value, + Arena* /* arena */) { + if (ptr_ == default_value) { + // Already set to default -- do nothing. + } else { + // Have another allocated string -- rather than throwing this away and + // resetting ptr_ to the canonical default string instance, we just reuse + // this instance. + *ptr_ = *default_value; + } + } + + // Called from generated code / reflection runtime only. Resets value to point + // to a default string pointer, with the semantics that this ArenaStringPtr + // does not own the pointed-to memory. Disregards initial value of ptr_ (so + // this is the *ONLY* safe method to call after construction or when + // reinitializing after becoming the active field in a oneof union). + inline void UnsafeSetDefault(const ::std::string* default_value) { + // Casting away 'const' is safe here: accessors ensure that ptr_ is only + // returned as a const if it is equal to default_value. + ptr_ = const_cast< ::std::string*>(default_value); + } + + // The 'NoArena' variants of methods below assume arena == NULL and are + // optimized to provide very little overhead relative to a raw string pointer + // (while still being in-memory compatible with other code that assumes + // ArenaStringPtr). Note the invariant that a class instance that has only + // ever been mutated by NoArena methods must *only* be in the String state + // (i.e., tag bits are not used), *NEVER* ArenaString. This allows all + // tagged-pointer manipulations to be avoided. + inline void SetNoArena(const ::std::string* default_value, + const ::std::string& value) { + if (ptr_ == default_value) { + CreateInstanceNoArena(&value); + } else { + *ptr_ = value; + } + } + + void SetNoArena(const ::std::string* default_value, ::std::string&& value) { + if (IsDefault(default_value)) { + ptr_ = new ::std::string(std::move(value)); + } else { + *ptr_ = std::move(value); + } + } + + void AssignWithDefault(const ::std::string* default_value, + ArenaStringPtr value); + + inline const ::std::string& GetNoArena() const { return *ptr_; } + + inline ::std::string* MutableNoArena(const ::std::string* default_value) { + if (ptr_ == default_value) { + CreateInstanceNoArena(default_value); + } + return ptr_; + } + + inline ::std::string* ReleaseNoArena(const ::std::string* default_value) { + if (ptr_ == default_value) { + return NULL; + } else { + return ReleaseNonDefaultNoArena(default_value); + } + } + + inline ::std::string* ReleaseNonDefaultNoArena( + const ::std::string* default_value) { + GOOGLE_DCHECK(!IsDefault(default_value)); + ::std::string* released = ptr_; + ptr_ = const_cast< ::std::string*>(default_value); + return released; + } + + inline void SetAllocatedNoArena(const ::std::string* default_value, + ::std::string* value) { + if (ptr_ != default_value) { + delete ptr_; + } + if (value != NULL) { + ptr_ = value; + } else { + ptr_ = const_cast< ::std::string*>(default_value); + } + } + + inline void DestroyNoArena(const ::std::string* default_value) { + if (ptr_ != default_value) { + delete ptr_; + } + } + + inline void ClearToEmptyNoArena(const ::std::string* default_value) { + if (ptr_ == default_value) { + // Nothing: already equal to default (which is the empty string). + } else { + ptr_->clear(); + } + } + + inline void ClearToDefaultNoArena(const ::std::string* default_value) { + if (ptr_ == default_value) { + // Nothing: already set to default. + } else { + // Reuse existing allocated instance. + *ptr_ = *default_value; + } + } + + // Internal accessor used only at parse time to provide direct access to the + // raw pointer from the shared parse routine (in the non-arenas case). The + // parse routine does the string allocation in order to save code size in the + // generated parsing code. + inline ::std::string** UnsafeRawStringPointer() { return &ptr_; } + + inline bool IsDefault(const ::std::string* default_value) const { + return ptr_ == default_value; + } + + // Internal accessors!!!! + void UnsafeSetTaggedPointer(TaggedPtr< ::std::string> value) { + ptr_ = value.Get(); + } + // Generated code only! An optimization, in certain cases the generated + // code is certain we can obtain a string with no default checks and + // tag tests. + ::std::string* UnsafeMutablePointer() { return ptr_; } + + private: + ::std::string* ptr_; + + PROTOBUF_NOINLINE + void CreateInstance(Arena* arena, const ::std::string* initial_value) { + GOOGLE_DCHECK(initial_value != NULL); + // uses "new ::std::string" when arena is nullptr + ptr_ = Arena::Create< ::std::string>(arena, *initial_value); + } + PROTOBUF_NOINLINE + void CreateInstanceNoArena(const ::std::string* initial_value) { + GOOGLE_DCHECK(initial_value != NULL); + ptr_ = new ::std::string(*initial_value); + } +}; + +} // namespace internal +} // namespace protobuf + +namespace protobuf { +namespace internal { + +inline void ArenaStringPtr::AssignWithDefault( + const ::std::string* default_value, ArenaStringPtr value) { + const ::std::string* me = *UnsafeRawStringPointer(); + const ::std::string* other = *value.UnsafeRawStringPointer(); + // If the pointers are the same then do nothing. + if (me != other) { + SetNoArena(default_value, value.GetNoArena()); + } +} + +} // namespace internal +} // namespace protobuf +} // namespace google + + +#include + +#endif // GOOGLE_PROTOBUF_ARENASTRING_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/code_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/code_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..c38c35da9761396b4fcb630d97c388422b7dd4c4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/code_generator.h @@ -0,0 +1,197 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Defines the abstract interface implemented by each of the language-specific +// code generators. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CODE_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_CODE_GENERATOR_H__ + +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +namespace io { +class ZeroCopyOutputStream; +} +class FileDescriptor; + +namespace compiler { +class AccessInfoMap; + +class Version; + +// Defined in this file. +class CodeGenerator; +class GeneratorContext; + +// The abstract interface to a class which generates code implementing a +// particular proto file in a particular language. A number of these may +// be registered with CommandLineInterface to support various languages. +class PROTOC_EXPORT CodeGenerator { + public: + inline CodeGenerator() {} + virtual ~CodeGenerator(); + + // Generates code for the given proto file, generating one or more files in + // the given output directory. + // + // A parameter to be passed to the generator can be specified on the command + // line. This is intended to be used to pass generator specific parameters. + // It is empty if no parameter was given. ParseGeneratorParameter (below), + // can be used to accept multiple parameters within the single parameter + // command line flag. + // + // Returns true if successful. Otherwise, sets *error to a description of + // the problem (e.g. "invalid parameter") and returns false. + virtual bool Generate(const FileDescriptor* file, + const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const = 0; + + // Generates code for all given proto files. + // + // WARNING: The canonical code generator design produces one or two output + // files per input .proto file, and we do not wish to encourage alternate + // designs. + // + // A parameter is given as passed on the command line, as in |Generate()| + // above. + // + // Returns true if successful. Otherwise, sets *error to a description of + // the problem (e.g. "invalid parameter") and returns false. + virtual bool GenerateAll(const std::vector& files, + const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const; + + // Sync with plugin.proto. + enum Feature { + FEATURE_PROTO3_OPTIONAL = 1, + }; + + // Implement this to indicate what features this code generator supports. + // This should be a bitwise OR of features from the Features enum in + // plugin.proto. + virtual uint64_t GetSupportedFeatures() const { return 0; } + + // This is no longer used, but this class is part of the opensource protobuf + // library, so it has to remain to keep vtables the same for the current + // version of the library. When protobufs does a api breaking change, the + // method can be removed. + virtual bool HasGenerateAll() const { return true; } + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CodeGenerator); +}; + +// CodeGenerators generate one or more files in a given directory. This +// abstract interface represents the directory to which the CodeGenerator is +// to write and other information about the context in which the Generator +// runs. +class PROTOC_EXPORT GeneratorContext { + public: + inline GeneratorContext() { + } + virtual ~GeneratorContext(); + + // Opens the given file, truncating it if it exists, and returns a + // ZeroCopyOutputStream that writes to the file. The caller takes ownership + // of the returned object. This method never fails (a dummy stream will be + // returned instead). + // + // The filename given should be relative to the root of the source tree. + // E.g. the C++ generator, when generating code for "foo/bar.proto", will + // generate the files "foo/bar.pb.h" and "foo/bar.pb.cc"; note that + // "foo/" is included in these filenames. The filename is not allowed to + // contain "." or ".." components. + virtual io::ZeroCopyOutputStream* Open(const std::string& filename) = 0; + + // Similar to Open() but the output will be appended to the file if exists + virtual io::ZeroCopyOutputStream* OpenForAppend(const std::string& filename); + + // Creates a ZeroCopyOutputStream which will insert code into the given file + // at the given insertion point. See plugin.proto (plugin.pb.h) for more + // information on insertion points. The default implementation + // assert-fails -- it exists only for backwards-compatibility. + // + // WARNING: This feature is currently EXPERIMENTAL and is subject to change. + virtual io::ZeroCopyOutputStream* OpenForInsert( + const std::string& filename, const std::string& insertion_point); + + // Returns a vector of FileDescriptors for all the files being compiled + // in this run. Useful for languages, such as Go, that treat files + // differently when compiled as a set rather than individually. + virtual void ListParsedFiles(std::vector* output); + + // Retrieves the version number of the protocol compiler associated with + // this GeneratorContext. + virtual void GetCompilerVersion(Version* version) const; + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GeneratorContext); +}; + +// The type GeneratorContext was once called OutputDirectory. This typedef +// provides backward compatibility. +typedef GeneratorContext OutputDirectory; + +// Several code generators treat the parameter argument as holding a +// list of options separated by commas. This helper function parses +// a set of comma-delimited name/value pairs: e.g., +// "foo=bar,baz,qux=corge" +// parses to the pairs: +// ("foo", "bar"), ("baz", ""), ("qux", "corge") +PROTOC_EXPORT void ParseGeneratorParameter( + const std::string&, std::vector >*); + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CODE_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/command_line_interface.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/command_line_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..8f71eb7ed759e197781e1fb47773435eb88faf72 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/command_line_interface.h @@ -0,0 +1,468 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Implements the Protocol Compiler front-end such that it may be reused by +// custom compilers written to support other languages. + +#ifndef GOOGLE_PROTOBUF_COMPILER_COMMAND_LINE_INTERFACE_H__ +#define GOOGLE_PROTOBUF_COMPILER_COMMAND_LINE_INTERFACE_H__ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace google { +namespace protobuf { + +class Descriptor; // descriptor.h +class DescriptorDatabase; // descriptor_database.h +class DescriptorPool; // descriptor.h +class FileDescriptor; // descriptor.h +class FileDescriptorSet; // descriptor.h +class FileDescriptorProto; // descriptor.pb.h +template +class RepeatedPtrField; // repeated_field.h +class SimpleDescriptorDatabase; // descriptor_database.h + +namespace compiler { + +class CodeGenerator; // code_generator.h +class GeneratorContext; // code_generator.h +class DiskSourceTree; // importer.h + +// This class implements the command-line interface to the protocol compiler. +// It is designed to make it very easy to create a custom protocol compiler +// supporting the languages of your choice. For example, if you wanted to +// create a custom protocol compiler binary which includes both the regular +// C++ support plus support for your own custom output "Foo", you would +// write a class "FooGenerator" which implements the CodeGenerator interface, +// then write a main() procedure like this: +// +// int main(int argc, char* argv[]) { +// google::protobuf::compiler::CommandLineInterface cli; +// +// // Support generation of C++ source and headers. +// google::protobuf::compiler::cpp::CppGenerator cpp_generator; +// cli.RegisterGenerator("--cpp_out", &cpp_generator, +// "Generate C++ source and header."); +// +// // Support generation of Foo code. +// FooGenerator foo_generator; +// cli.RegisterGenerator("--foo_out", &foo_generator, +// "Generate Foo file."); +// +// return cli.Run(argc, argv); +// } +// +// The compiler is invoked with syntax like: +// protoc --cpp_out=outdir --foo_out=outdir --proto_path=src src/foo.proto +// +// The .proto file to compile can be specified on the command line using either +// its physical file path, or a virtual path relative to a directory specified +// in --proto_path. For example, for src/foo.proto, the following two protoc +// invocations work the same way: +// 1. protoc --proto_path=src src/foo.proto (physical file path) +// 2. protoc --proto_path=src foo.proto (virtual path relative to src) +// +// If a file path can be interpreted both as a physical file path and as a +// relative virtual path, the physical file path takes precendence. +// +// For a full description of the command-line syntax, invoke it with --help. +class PROTOC_EXPORT CommandLineInterface { + public: + static const char* const kPathSeparator; + + CommandLineInterface(); + ~CommandLineInterface(); + + // Register a code generator for a language. + // + // Parameters: + // * flag_name: The command-line flag used to specify an output file of + // this type. The name must start with a '-'. If the name is longer + // than one letter, it must start with two '-'s. + // * generator: The CodeGenerator which will be called to generate files + // of this type. + // * help_text: Text describing this flag in the --help output. + // + // Some generators accept extra parameters. You can specify this parameter + // on the command-line by placing it before the output directory, separated + // by a colon: + // protoc --foo_out=enable_bar:outdir + // The text before the colon is passed to CodeGenerator::Generate() as the + // "parameter". + void RegisterGenerator(const std::string& flag_name, CodeGenerator* generator, + const std::string& help_text); + + // Register a code generator for a language. + // Besides flag_name you can specify another option_flag_name that could be + // used to pass extra parameters to the registered code generator. + // Suppose you have registered a generator by calling: + // command_line_interface.RegisterGenerator("--foo_out", "--foo_opt", ...) + // Then you could invoke the compiler with a command like: + // protoc --foo_out=enable_bar:outdir --foo_opt=enable_baz + // This will pass "enable_bar,enable_baz" as the parameter to the generator. + void RegisterGenerator(const std::string& flag_name, + const std::string& option_flag_name, + CodeGenerator* generator, + const std::string& help_text); + + // Enables "plugins". In this mode, if a command-line flag ends with "_out" + // but does not match any registered generator, the compiler will attempt to + // find a "plugin" to implement the generator. Plugins are just executables. + // They should live somewhere in the PATH. + // + // The compiler determines the executable name to search for by concatenating + // exe_name_prefix with the unrecognized flag name, removing "_out". So, for + // example, if exe_name_prefix is "protoc-" and you pass the flag --foo_out, + // the compiler will try to run the program "protoc-gen-foo". + // + // The plugin program should implement the following usage: + // plugin [--out=OUTDIR] [--parameter=PARAMETER] PROTO_FILES < DESCRIPTORS + // --out indicates the output directory (as passed to the --foo_out + // parameter); if omitted, the current directory should be used. --parameter + // gives the generator parameter, if any was provided (see below). The + // PROTO_FILES list the .proto files which were given on the compiler + // command-line; these are the files for which the plugin is expected to + // generate output code. Finally, DESCRIPTORS is an encoded FileDescriptorSet + // (as defined in descriptor.proto). This is piped to the plugin's stdin. + // The set will include descriptors for all the files listed in PROTO_FILES as + // well as all files that they import. The plugin MUST NOT attempt to read + // the PROTO_FILES directly -- it must use the FileDescriptorSet. + // + // The plugin should generate whatever files are necessary, as code generators + // normally do. It should write the names of all files it generates to + // stdout. The names should be relative to the output directory, NOT absolute + // names or relative to the current directory. If any errors occur, error + // messages should be written to stderr. If an error is fatal, the plugin + // should exit with a non-zero exit code. + // + // Plugins can have generator parameters similar to normal built-in + // generators. Extra generator parameters can be passed in via a matching + // "_opt" parameter. For example: + // protoc --plug_out=enable_bar:outdir --plug_opt=enable_baz + // This will pass "enable_bar,enable_baz" as the parameter to the plugin. + // + void AllowPlugins(const std::string& exe_name_prefix); + + // Run the Protocol Compiler with the given command-line parameters. + // Returns the error code which should be returned by main(). + // + // It may not be safe to call Run() in a multi-threaded environment because + // it calls strerror(). I'm not sure why you'd want to do this anyway. + int Run(int argc, const char* const argv[]); + + // DEPRECATED. Calling this method has no effect. Protocol compiler now + // always try to find the .proto file relative to the current directory + // first and if the file is not found, it will then treat the input path + // as a virtual path. + void SetInputsAreProtoPathRelative(bool /* enable */) {} + + // Provides some text which will be printed when the --version flag is + // used. The version of libprotoc will also be printed on the next line + // after this text. + void SetVersionInfo(const std::string& text) { version_info_ = text; } + + + private: + // ----------------------------------------------------------------- + + class ErrorPrinter; + class GeneratorContextImpl; + class MemoryOutputStream; + typedef std::unordered_map> + GeneratorContextMap; + + // Clear state from previous Run(). + void Clear(); + + // Remaps the proto file so that it is relative to one of the directories + // in proto_path_. Returns false if an error occurred. + bool MakeProtoProtoPathRelative(DiskSourceTree* source_tree, + std::string* proto, + DescriptorDatabase* fallback_database); + + // Remaps each file in input_files_ so that it is relative to one of the + // directories in proto_path_. Returns false if an error occurred. + bool MakeInputsBeProtoPathRelative(DiskSourceTree* source_tree, + DescriptorDatabase* fallback_database); + + // Is this .proto file whitelisted, or do we have a command-line flag allowing + // us to use proto3 optional? This is a temporary control to avoid people from + // using proto3 optional until code generators have implemented it. + bool AllowProto3Optional(const FileDescriptor& file) const; + + // Fails if these files use proto3 optional and the code generator doesn't + // support it. This is a permanent check. + bool EnforceProto3OptionalSupport( + const std::string& codegen_name, uint64 supported_features, + const std::vector& parsed_files) const; + + + // Return status for ParseArguments() and InterpretArgument(). + enum ParseArgumentStatus { + PARSE_ARGUMENT_DONE_AND_CONTINUE, + PARSE_ARGUMENT_DONE_AND_EXIT, + PARSE_ARGUMENT_FAIL + }; + + // Parse all command-line arguments. + ParseArgumentStatus ParseArguments(int argc, const char* const argv[]); + + // Read an argument file and append the file's content to the list of + // arguments. Return false if the file cannot be read. + bool ExpandArgumentFile(const std::string& file, + std::vector* arguments); + + // Parses a command-line argument into a name/value pair. Returns + // true if the next argument in the argv should be used as the value, + // false otherwise. + // + // Examples: + // "-Isrc/protos" -> + // name = "-I", value = "src/protos" + // "--cpp_out=src/foo.pb2.cc" -> + // name = "--cpp_out", value = "src/foo.pb2.cc" + // "foo.proto" -> + // name = "", value = "foo.proto" + bool ParseArgument(const char* arg, std::string* name, std::string* value); + + // Interprets arguments parsed with ParseArgument. + ParseArgumentStatus InterpretArgument(const std::string& name, + const std::string& value); + + // Print the --help text to stderr. + void PrintHelpText(); + + // Loads proto_path_ into the provided source_tree. + bool InitializeDiskSourceTree(DiskSourceTree* source_tree, + DescriptorDatabase* fallback_database); + + // Verify that all the input files exist in the given database. + bool VerifyInputFilesInDescriptors(DescriptorDatabase* fallback_database); + + // Parses input_files_ into parsed_files + bool ParseInputFiles(DescriptorPool* descriptor_pool, + DiskSourceTree* source_tree, + std::vector* parsed_files); + + // Generate the given output file from the given input. + struct OutputDirective; // see below + bool GenerateOutput(const std::vector& parsed_files, + const OutputDirective& output_directive, + GeneratorContext* generator_context); + bool GeneratePluginOutput( + const std::vector& parsed_files, + const std::string& plugin_name, const std::string& parameter, + GeneratorContext* generator_context, std::string* error); + + // Implements --encode and --decode. + bool EncodeOrDecode(const DescriptorPool* pool); + + // Implements the --descriptor_set_out option. + bool WriteDescriptorSet( + const std::vector& parsed_files); + + // Implements the --dependency_out option + bool GenerateDependencyManifestFile( + const std::vector& parsed_files, + const GeneratorContextMap& output_directories, + DiskSourceTree* source_tree); + + // Get all transitive dependencies of the given file (including the file + // itself), adding them to the given list of FileDescriptorProtos. The + // protos will be ordered such that every file is listed before any file that + // depends on it, so that you can call DescriptorPool::BuildFile() on them + // in order. Any files in *already_seen will not be added, and each file + // added will be inserted into *already_seen. If include_source_code_info is + // true then include the source code information in the FileDescriptorProtos. + // If include_json_name is true, populate the json_name field of + // FieldDescriptorProto for all fields. + static void GetTransitiveDependencies( + const FileDescriptor* file, bool include_json_name, + bool include_source_code_info, + std::set* already_seen, + RepeatedPtrField* output); + + // Implements the --print_free_field_numbers. This function prints free field + // numbers into stdout for the message and it's nested message types in + // post-order, i.e. nested types first. Printed range are left-right + // inclusive, i.e. [a, b]. + // + // Groups: + // For historical reasons, groups are considered to share the same + // field number space with the parent message, thus it will not print free + // field numbers for groups. The field numbers used in the groups are + // excluded in the free field numbers of the parent message. + // + // Extension Ranges: + // Extension ranges are considered ocuppied field numbers and they will not be + // listed as free numbers in the output. + void PrintFreeFieldNumbers(const Descriptor* descriptor); + + // ----------------------------------------------------------------- + + // The name of the executable as invoked (i.e. argv[0]). + std::string executable_name_; + + // Version info set with SetVersionInfo(). + std::string version_info_; + + // Registered generators. + struct GeneratorInfo { + std::string flag_name; + std::string option_flag_name; + CodeGenerator* generator; + std::string help_text; + }; + typedef std::map GeneratorMap; + GeneratorMap generators_by_flag_name_; + GeneratorMap generators_by_option_name_; + // A map from generator names to the parameters specified using the option + // flag. For example, if the user invokes the compiler with: + // protoc --foo_out=outputdir --foo_opt=enable_bar ... + // Then there will be an entry ("--foo_out", "enable_bar") in this map. + std::map generator_parameters_; + // Similar to generator_parameters_, but stores the parameters for plugins. + std::map plugin_parameters_; + + // See AllowPlugins(). If this is empty, plugins aren't allowed. + std::string plugin_prefix_; + + // Maps specific plugin names to files. When executing a plugin, this map + // is searched first to find the plugin executable. If not found here, the + // PATH (or other OS-specific search strategy) is searched. + std::map plugins_; + + // Stuff parsed from command line. + enum Mode { + MODE_COMPILE, // Normal mode: parse .proto files and compile them. + MODE_ENCODE, // --encode: read text from stdin, write binary to stdout. + MODE_DECODE, // --decode: read binary from stdin, write text to stdout. + MODE_PRINT, // Print mode: print info of the given .proto files and exit. + }; + + Mode mode_ = MODE_COMPILE; + + enum PrintMode { + PRINT_NONE, // Not in MODE_PRINT + PRINT_FREE_FIELDS, // --print_free_fields + }; + + PrintMode print_mode_ = PRINT_NONE; + + enum ErrorFormat { + ERROR_FORMAT_GCC, // GCC error output format (default). + ERROR_FORMAT_MSVS // Visual Studio output (--error_format=msvs). + }; + + ErrorFormat error_format_ = ERROR_FORMAT_GCC; + + std::vector > + proto_path_; // Search path for proto files. + std::vector input_files_; // Names of the input proto files. + + // Names of proto files which are allowed to be imported. Used by build + // systems to enforce depend-on-what-you-import. + std::set direct_dependencies_; + bool direct_dependencies_explicitly_set_ = false; + + // If there's a violation of depend-on-what-you-import, this string will be + // presented to the user. "%s" will be replaced with the violating import. + std::string direct_dependencies_violation_msg_; + + // output_directives_ lists all the files we are supposed to output and what + // generator to use for each. + struct OutputDirective { + std::string name; // E.g. "--foo_out" + CodeGenerator* generator; // NULL for plugins + std::string parameter; + std::string output_location; + }; + std::vector output_directives_; + + // When using --encode or --decode, this names the type we are encoding or + // decoding. (Empty string indicates --decode_raw.) + std::string codec_type_; + + // If --descriptor_set_in was given, these are filenames containing + // parsed FileDescriptorSets to be used for loading protos. Otherwise, empty. + std::vector descriptor_set_in_names_; + + // If --descriptor_set_out was given, this is the filename to which the + // FileDescriptorSet should be written. Otherwise, empty. + std::string descriptor_set_out_name_; + + // If --dependency_out was given, this is the path to the file where the + // dependency file will be written. Otherwise, empty. + std::string dependency_out_name_; + + // True if --include_imports was given, meaning that we should + // write all transitive dependencies to the DescriptorSet. Otherwise, only + // the .proto files listed on the command-line are added. + bool imports_in_descriptor_set_; + + // True if --include_source_info was given, meaning that we should not strip + // SourceCodeInfo from the DescriptorSet. + bool source_info_in_descriptor_set_ = false; + + // Was the --disallow_services flag used? + bool disallow_services_ = false; + + // Was the --experimental_allow_proto3_optional flag used? + bool allow_proto3_optional_ = false; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CommandLineInterface); +}; + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_COMMAND_LINE_INTERFACE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/cpp/cpp_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/cpp/cpp_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..365b8e12f8f7897b0bb881b891ae525a2ca72ca9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/cpp/cpp_generator.h @@ -0,0 +1,111 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Generates C++ code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CPP_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_CPP_GENERATOR_H__ + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace cpp { + +// CodeGenerator implementation which generates a C++ source file and +// header. If you create your own protocol compiler binary and you want +// it to support C++ output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT CppGenerator : public CodeGenerator { + public: + CppGenerator(); + ~CppGenerator(); + + enum class Runtime { + kGoogle3, // Use the internal google3 runtime. + kOpensource, // Use the open-source runtime. + + // Use the open-source runtime with google3 #include paths. We make these + // absolute to avoid ambiguity, so the runtime will be #included like: + // #include "third_party/protobuf/.../google/protobuf/message.h" + kOpensourceGoogle3 + }; + + void set_opensource_runtime(bool opensource) { + opensource_runtime_ = opensource; + } + + // If set to a non-empty string, generated code will do: + // #include "/google/protobuf/message.h" + // instead of: + // #include + // This has no effect if opensource_runtime = false. + void set_runtime_include_base(const std::string& base) { + runtime_include_base_ = base; + } + + // implements CodeGenerator ---------------------------------------- + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const override; + + uint64_t GetSupportedFeatures() const override { + // We don't fully support this yet, but this is needed to unblock the tests, + // and we will have full support before the experimental flag is removed. + return FEATURE_PROTO3_OPTIONAL; + } + + private: + bool opensource_runtime_ = true; + std::string runtime_include_base_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppGenerator); +}; + +} // namespace cpp +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CPP_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..b85832eb7688472c8547cadd4c4e83ffba6d1fec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_generator.h @@ -0,0 +1,75 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates C# code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CSHARP_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_CSHARP_GENERATOR_H__ + +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace csharp { + +// CodeGenerator implementation which generates a C# source file and +// header. If you create your own protocol compiler binary and you want +// it to support C# output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + Generator(); + ~Generator(); + bool Generate( + const FileDescriptor* file, + const string& parameter, + GeneratorContext* generator_context, + string* error) const override; + uint64_t GetSupportedFeatures() const override; +}; + +} // namespace csharp +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CSHARP_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_names.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_names.h new file mode 100644 index 0000000000000000000000000000000000000000..972b097817a14ef490829ad2fd725a58122afbcf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_names.h @@ -0,0 +1,112 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Provides a mechanism for mapping a descriptor to the +// fully-qualified name of the corresponding C# class. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CSHARP_NAMES_H__ +#define GOOGLE_PROTOBUF_COMPILER_CSHARP_NAMES_H__ + +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class FileDescriptor; +class ServiceDescriptor; + +namespace compiler { +namespace csharp { + +// Requires: +// descriptor != NULL +// +// Returns: +// The namespace to use for given file descriptor. +string PROTOC_EXPORT GetFileNamespace(const FileDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified C# class name. +string PROTOC_EXPORT GetClassName(const Descriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified name of the C# class that provides +// access to the file descriptor. Proto compiler generates +// such class for each .proto file processed. +string PROTOC_EXPORT GetReflectionClassName(const FileDescriptor* descriptor); + +// Generates output file name for given file descriptor. If generate_directories +// is true, the output file will be put under directory corresponding to file's +// namespace. base_namespace can be used to strip some of the top level +// directories. E.g. for file with namespace "Bar.Foo" and base_namespace="Bar", +// the resulting file will be put under directory "Foo" (and not "Bar/Foo"). +// +// Requires: +// descriptor != NULL +// error != NULL +// +// Returns: +// The file name to use as output file for given file descriptor. In case +// of failure, this function will return empty string and error parameter +// will contain the error message. +string PROTOC_EXPORT GetOutputFile(const FileDescriptor* descriptor, + const string file_extension, + const bool generate_directories, + const string base_namespace, string* error); + +} // namespace csharp +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CSHARP_NAMES_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/importer.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/importer.h new file mode 100644 index 0000000000000000000000000000000000000000..1f2c6df0962a9141c3b567524cc6062e00f4ea11 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/importer.h @@ -0,0 +1,341 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file is the public interface to the .proto file parser. + +#ifndef GOOGLE_PROTOBUF_COMPILER_IMPORTER_H__ +#define GOOGLE_PROTOBUF_COMPILER_IMPORTER_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +namespace io { +class ZeroCopyInputStream; +} + +namespace compiler { + +// Defined in this file. +class Importer; +class MultiFileErrorCollector; +class SourceTree; +class DiskSourceTree; + +// TODO(kenton): Move all SourceTree stuff to a separate file? + +// An implementation of DescriptorDatabase which loads files from a SourceTree +// and parses them. +// +// Note: This class is not thread-safe since it maintains a table of source +// code locations for error reporting. However, when a DescriptorPool wraps +// a DescriptorDatabase, it uses mutex locking to make sure only one method +// of the database is called at a time, even if the DescriptorPool is used +// from multiple threads. Therefore, there is only a problem if you create +// multiple DescriptorPools wrapping the same SourceTreeDescriptorDatabase +// and use them from multiple threads. +// +// Note: This class does not implement FindFileContainingSymbol() or +// FindFileContainingExtension(); these will always return false. +class PROTOBUF_EXPORT SourceTreeDescriptorDatabase : public DescriptorDatabase { + public: + SourceTreeDescriptorDatabase(SourceTree* source_tree); + + // If non-NULL, fallback_database will be checked if a file doesn't exist in + // the specified source_tree. + SourceTreeDescriptorDatabase(SourceTree* source_tree, + DescriptorDatabase* fallback_database); + ~SourceTreeDescriptorDatabase(); + + // Instructs the SourceTreeDescriptorDatabase to report any parse errors + // to the given MultiFileErrorCollector. This should be called before + // parsing. error_collector must remain valid until either this method + // is called again or the SourceTreeDescriptorDatabase is destroyed. + void RecordErrorsTo(MultiFileErrorCollector* error_collector) { + error_collector_ = error_collector; + } + + // Gets a DescriptorPool::ErrorCollector which records errors to the + // MultiFileErrorCollector specified with RecordErrorsTo(). This collector + // has the ability to determine exact line and column numbers of errors + // from the information given to it by the DescriptorPool. + DescriptorPool::ErrorCollector* GetValidationErrorCollector() { + using_validation_error_collector_ = true; + return &validation_error_collector_; + } + + // implements DescriptorDatabase ----------------------------------- + bool FindFileByName(const std::string& filename, + FileDescriptorProto* output) override; + bool FindFileContainingSymbol(const std::string& symbol_name, + FileDescriptorProto* output) override; + bool FindFileContainingExtension(const std::string& containing_type, + int field_number, + FileDescriptorProto* output) override; + + private: + class SingleFileErrorCollector; + + SourceTree* source_tree_; + DescriptorDatabase* fallback_database_; + MultiFileErrorCollector* error_collector_; + + class PROTOBUF_EXPORT ValidationErrorCollector + : public DescriptorPool::ErrorCollector { + public: + ValidationErrorCollector(SourceTreeDescriptorDatabase* owner); + ~ValidationErrorCollector(); + + // implements ErrorCollector --------------------------------------- + void AddError(const std::string& filename, const std::string& element_name, + const Message* descriptor, ErrorLocation location, + const std::string& message) override; + + void AddWarning(const std::string& filename, + const std::string& element_name, const Message* descriptor, + ErrorLocation location, + const std::string& message) override; + + private: + SourceTreeDescriptorDatabase* owner_; + }; + friend class ValidationErrorCollector; + + bool using_validation_error_collector_; + SourceLocationTable source_locations_; + ValidationErrorCollector validation_error_collector_; +}; + +// Simple interface for parsing .proto files. This wraps the process +// of opening the file, parsing it with a Parser, recursively parsing all its +// imports, and then cross-linking the results to produce a FileDescriptor. +// +// This is really just a thin wrapper around SourceTreeDescriptorDatabase. +// You may find that SourceTreeDescriptorDatabase is more flexible. +// +// TODO(kenton): I feel like this class is not well-named. +class PROTOBUF_EXPORT Importer { + public: + Importer(SourceTree* source_tree, MultiFileErrorCollector* error_collector); + ~Importer(); + + // Import the given file and build a FileDescriptor representing it. If + // the file is already in the DescriptorPool, the existing FileDescriptor + // will be returned. The FileDescriptor is property of the DescriptorPool, + // and will remain valid until it is destroyed. If any errors occur, they + // will be reported using the error collector and Import() will return NULL. + // + // A particular Importer object will only report errors for a particular + // file once. All future attempts to import the same file will return NULL + // without reporting any errors. The idea is that you might want to import + // a lot of files without seeing the same errors over and over again. If + // you want to see errors for the same files repeatedly, you can use a + // separate Importer object to import each one (but use the same + // DescriptorPool so that they can be cross-linked). + const FileDescriptor* Import(const std::string& filename); + + // The DescriptorPool in which all imported FileDescriptors and their + // contents are stored. + inline const DescriptorPool* pool() const { return &pool_; } + + void AddUnusedImportTrackFile(const std::string& file_name, + bool is_error = false); + void ClearUnusedImportTrackFiles(); + + + private: + SourceTreeDescriptorDatabase database_; + DescriptorPool pool_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Importer); +}; + +// If the importer encounters problems while trying to import the proto files, +// it reports them to a MultiFileErrorCollector. +class PROTOBUF_EXPORT MultiFileErrorCollector { + public: + inline MultiFileErrorCollector() {} + virtual ~MultiFileErrorCollector(); + + // Line and column numbers are zero-based. A line number of -1 indicates + // an error with the entire file (e.g. "not found"). + virtual void AddError(const std::string& filename, int line, int column, + const std::string& message) = 0; + + virtual void AddWarning(const std::string& filename, int line, int column, + const std::string& message) {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MultiFileErrorCollector); +}; + +// Abstract interface which represents a directory tree containing proto files. +// Used by the default implementation of Importer to resolve import statements +// Most users will probably want to use the DiskSourceTree implementation, +// below. +class PROTOBUF_EXPORT SourceTree { + public: + inline SourceTree() {} + virtual ~SourceTree(); + + // Open the given file and return a stream that reads it, or NULL if not + // found. The caller takes ownership of the returned object. The filename + // must be a path relative to the root of the source tree and must not + // contain "." or ".." components. + virtual io::ZeroCopyInputStream* Open(const std::string& filename) = 0; + + // If Open() returns NULL, calling this method immediately will return an + // description of the error. + // Subclasses should implement this method and return a meaningful value for + // better error reporting. + // TODO(xiaofeng): change this to a pure virtual function. + virtual std::string GetLastErrorMessage(); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(SourceTree); +}; + +// An implementation of SourceTree which loads files from locations on disk. +// Multiple mappings can be set up to map locations in the DiskSourceTree to +// locations in the physical filesystem. +class PROTOBUF_EXPORT DiskSourceTree : public SourceTree { + public: + DiskSourceTree(); + ~DiskSourceTree(); + + // Map a path on disk to a location in the SourceTree. The path may be + // either a file or a directory. If it is a directory, the entire tree + // under it will be mapped to the given virtual location. To map a directory + // to the root of the source tree, pass an empty string for virtual_path. + // + // If multiple mapped paths apply when opening a file, they will be searched + // in order. For example, if you do: + // MapPath("bar", "foo/bar"); + // MapPath("", "baz"); + // and then you do: + // Open("bar/qux"); + // the DiskSourceTree will first try to open foo/bar/qux, then baz/bar/qux, + // returning the first one that opens successfully. + // + // disk_path may be an absolute path or relative to the current directory, + // just like a path you'd pass to open(). + void MapPath(const std::string& virtual_path, const std::string& disk_path); + + // Return type for DiskFileToVirtualFile(). + enum DiskFileToVirtualFileResult { + SUCCESS, + SHADOWED, + CANNOT_OPEN, + NO_MAPPING + }; + + // Given a path to a file on disk, find a virtual path mapping to that + // file. The first mapping created with MapPath() whose disk_path contains + // the filename is used. However, that virtual path may not actually be + // usable to open the given file. Possible return values are: + // * SUCCESS: The mapping was found. *virtual_file is filled in so that + // calling Open(*virtual_file) will open the file named by disk_file. + // * SHADOWED: A mapping was found, but using Open() to open this virtual + // path will end up returning some different file. This is because some + // other mapping with a higher precedence also matches this virtual path + // and maps it to a different file that exists on disk. *virtual_file + // is filled in as it would be in the SUCCESS case. *shadowing_disk_file + // is filled in with the disk path of the file which would be opened if + // you were to call Open(*virtual_file). + // * CANNOT_OPEN: The mapping was found and was not shadowed, but the + // file specified cannot be opened. When this value is returned, + // errno will indicate the reason the file cannot be opened. *virtual_file + // will be set to the virtual path as in the SUCCESS case, even though + // it is not useful. + // * NO_MAPPING: Indicates that no mapping was found which contains this + // file. + DiskFileToVirtualFileResult DiskFileToVirtualFile( + const std::string& disk_file, std::string* virtual_file, + std::string* shadowing_disk_file); + + // Given a virtual path, find the path to the file on disk. + // Return true and update disk_file with the on-disk path if the file exists. + // Return false and leave disk_file untouched if the file doesn't exist. + bool VirtualFileToDiskFile(const std::string& virtual_file, + std::string* disk_file); + + // implements SourceTree ------------------------------------------- + io::ZeroCopyInputStream* Open(const std::string& filename) override; + + std::string GetLastErrorMessage() override; + + private: + struct Mapping { + std::string virtual_path; + std::string disk_path; + + inline Mapping(const std::string& virtual_path_param, + const std::string& disk_path_param) + : virtual_path(virtual_path_param), disk_path(disk_path_param) {} + }; + std::vector mappings_; + std::string last_error_message_; + + // Like Open(), but returns the on-disk path in disk_file if disk_file is + // non-NULL and the file could be successfully opened. + io::ZeroCopyInputStream* OpenVirtualFile(const std::string& virtual_file, + std::string* disk_file); + + // Like Open() but given the actual on-disk path. + io::ZeroCopyInputStream* OpenDiskFile(const std::string& filename); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DiskSourceTree); +}; + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_IMPORTER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..99014924c68d1b481a64ef1a65cf3d787356c511 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_generator.h @@ -0,0 +1,81 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Generates Java code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_JAVA_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_JAVA_GENERATOR_H__ + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace java { + +// CodeGenerator implementation which generates Java code. If you create your +// own protocol compiler binary and you want it to support Java output, you +// can do so by registering an instance of this CodeGenerator with the +// CommandLineInterface in your main() function. +class PROTOC_EXPORT JavaGenerator : public CodeGenerator { + public: + JavaGenerator(); + ~JavaGenerator(); + + // implements CodeGenerator ---------------------------------------- + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* context, std::string* error) const override; + + uint64_t GetSupportedFeatures() const override; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(JavaGenerator); +}; + +} // namespace java +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_JAVA_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_names.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_names.h new file mode 100644 index 0000000000000000000000000000000000000000..1e82f60fb15b283a8db5e33f3c0ec2b02288ab8b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_names.h @@ -0,0 +1,117 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Provides a mechanism for mapping a descriptor to the +// fully-qualified name of the corresponding Java class. + +#ifndef GOOGLE_PROTOBUF_COMPILER_JAVA_NAMES_H__ +#define GOOGLE_PROTOBUF_COMPILER_JAVA_NAMES_H__ + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class FileDescriptor; +class FieldDescriptor; +class ServiceDescriptor; + +namespace compiler { +namespace java { + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const Descriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const EnumDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const FileDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const ServiceDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// Java package name. +std::string FileJavaPackage(const FileDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// Returns: +// Capitalized camel case name field name. +std::string CapitalizedFieldName(const FieldDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// Returns: +// Primitive Java type name for the field. +const char* PrimitiveTypeName(const FieldDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// Returns: +// Boes primitive Java type name for the field. +const char* BoxedPrimitiveTypeName(const FieldDescriptor* descriptor); + +} // namespace java +} // namespace compiler +} // namespace protobuf +} // namespace google +#endif // GOOGLE_PROTOBUF_COMPILER_JAVA_NAMES_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/js_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/js_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..87f69bd39e91d3fd5857ca9513e1469a25dc6409 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/js_generator.h @@ -0,0 +1,344 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates JavaScript code for a given .proto file. +// +#ifndef GOOGLE_PROTOBUF_COMPILER_JS_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_JS_GENERATOR_H__ + +#include +#include + +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class FieldDescriptor; +class OneofDescriptor; +class FileDescriptor; + +namespace io { +class Printer; +} + +namespace compiler { +namespace js { + +struct GeneratorOptions { + // Output path. + std::string output_dir; + // Namespace prefix. + std::string namespace_prefix; + // Enable binary-format support? + bool binary; + // What style of imports should be used. + enum ImportStyle { + kImportClosure, // goog.require() + kImportCommonJs, // require() + kImportCommonJsStrict, // require() with no global export + kImportBrowser, // no import statements + kImportEs6, // import { member } from '' + } import_style; + + GeneratorOptions() + : output_dir("."), + namespace_prefix(""), + binary(false), + import_style(kImportClosure), + add_require_for_enums(false), + testonly(false), + library(""), + error_on_name_conflict(false), + extension(".js"), + one_output_file_per_input_file(false), + annotate_code(false) {} + + bool ParseFromOptions( + const std::vector >& options, + std::string* error); + + // Returns the file name extension to use for generated code. + std::string GetFileNameExtension() const { + return import_style == kImportClosure ? extension : "_pb.js"; + } + + enum OutputMode { + // Create an output file for each input .proto file. + kOneOutputFilePerInputFile, + // Create an output file for each type. + kOneOutputFilePerSCC, + // Put everything in a single file named by the library option. + kEverythingInOneFile, + }; + + // Indicates how to output the generated code based on the provided options. + OutputMode output_mode() const; + + // The remaining options are only relevant when we are using kImportClosure. + + // Add a `goog.requires()` call for each enum type used. If not set, a + // forward declaration with `goog.forwardDeclare` is produced instead. + bool add_require_for_enums; + // Set this as a test-only module via `goog.setTestOnly();`. + bool testonly; + // Create a library with name _lib.js rather than a separate .js file + // per type? + std::string library; + // Error if there are two types that would generate the same output file? + bool error_on_name_conflict; + // The extension to use for output file names. + std::string extension; + // Create a separate output file for each input file? + bool one_output_file_per_input_file; + // If true, we should append annotations as commen on the last line for + // generated .js file. Annotations used by tools like https://kythe.io + // to provide cross-references between .js and .proto files. Annotations + // are enced as base64 proto of GeneratedCodeInfo message (see + // descriptor.proto). + bool annotate_code; +}; + +// CodeGenerator implementation which generates a JavaScript source file and +// header. If you create your own protocol compiler binary and you want it to +// support JavaScript output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + Generator() {} + virtual ~Generator() {} + + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* context, std::string* error) const override { + *error = "Unimplemented Generate() method. Call GenerateAll() instead."; + return false; + } + + bool HasGenerateAll() const override { return true; } + + bool GenerateAll(const std::vector& files, + const std::string& parameter, GeneratorContext* context, + std::string* error) const override; + + uint64 GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + private: + void GenerateHeader(const GeneratorOptions& options, + const FileDescriptor* file, io::Printer* printer) const; + + // Generate goog.provides() calls. + void FindProvides(const GeneratorOptions& options, io::Printer* printer, + const std::vector& file, + std::set* provided) const; + void FindProvidesForFile(const GeneratorOptions& options, + io::Printer* printer, const FileDescriptor* file, + std::set* provided) const; + void FindProvidesForMessage(const GeneratorOptions& options, + io::Printer* printer, const Descriptor* desc, + std::set* provided) const; + void FindProvidesForEnum(const GeneratorOptions& options, + io::Printer* printer, const EnumDescriptor* enumdesc, + std::set* provided) const; + // For extension fields at file scope. + void FindProvidesForFields(const GeneratorOptions& options, + io::Printer* printer, + const std::vector& fields, + std::set* provided) const; + // Print the goog.provides() found by the methods above. + void GenerateProvides(const GeneratorOptions& options, io::Printer* printer, + std::set* provided) const; + + // Generate goog.setTestOnly() if indicated. + void GenerateTestOnly(const GeneratorOptions& options, + io::Printer* printer) const; + + // Generate goog.requires() calls. + void GenerateRequiresForLibrary( + const GeneratorOptions& options, io::Printer* printer, + const std::vector& files, + std::set* provided) const; + void GenerateRequiresForSCC(const GeneratorOptions& options, + io::Printer* printer, const SCC* scc, + std::set* provided) const; + // For extension fields at file scope. + void GenerateRequiresForExtensions( + const GeneratorOptions& options, io::Printer* printer, + const std::vector& fields, + std::set* provided) const; + void GenerateRequiresImpl(const GeneratorOptions& options, + io::Printer* printer, + std::set* required, + std::set* forwards, + std::set* provided, bool require_jspb, + bool require_extension, bool require_map) const; + void FindRequiresForMessage(const GeneratorOptions& options, + const Descriptor* desc, + std::set* required, + std::set* forwards, + bool* have_message) const; + void FindRequiresForField(const GeneratorOptions& options, + const FieldDescriptor* field, + std::set* required, + std::set* forwards) const; + void FindRequiresForExtension(const GeneratorOptions& options, + const FieldDescriptor* field, + std::set* required, + std::set* forwards) const; + // Generate all things in a proto file into one file. + // If use_short_name is true, the generated file's name will only be short + // name that without directory, otherwise filename equals file->name() + bool GenerateFile(const FileDescriptor* file, const GeneratorOptions& options, + GeneratorContext* context, bool use_short_name) const; + void GenerateFile(const GeneratorOptions& options, io::Printer* printer, + const FileDescriptor* file) const; + + // Generate definitions for all message classes and enums in all files, + // processing the files in dependence order. + void GenerateFilesInDepOrder( + const GeneratorOptions& options, io::Printer* printer, + const std::vector& file) const; + // Helper for above. + void GenerateFileAndDeps(const GeneratorOptions& options, + io::Printer* printer, const FileDescriptor* root, + std::set* all_files, + std::set* generated) const; + + // Generate definitions for all message classes and enums. + void GenerateClassesAndEnums(const GeneratorOptions& options, + io::Printer* printer, + const FileDescriptor* file) const; + + void GenerateFieldValueExpression(io::Printer* printer, + const char* obj_reference, + const FieldDescriptor* field, + bool use_default) const; + + // Generate definition for one class. + void GenerateClass(const GeneratorOptions& options, io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassConstructor(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFieldInfo(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassConstructorAndDeclareExtensionFieldInfo( + const GeneratorOptions& options, io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassXid(const GeneratorOptions& options, io::Printer* printer, + const Descriptor* desc) const; + void GenerateOneofCaseDefinition(const GeneratorOptions& options, + io::Printer* printer, + const OneofDescriptor* oneof) const; + void GenerateObjectTypedef(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassToObject(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFieldToObject(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + void GenerateClassFromObject(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFieldFromObject(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + void GenerateClassRegistration(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFields(const GeneratorOptions& options, + io::Printer* printer, const Descriptor* desc) const; + void GenerateClassField(const GeneratorOptions& options, io::Printer* printer, + const FieldDescriptor* desc) const; + void GenerateClassExtensionFieldInfo(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassDeserialize(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassDeserializeBinary(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassDeserializeBinaryField(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + void GenerateClassSerializeBinary(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassSerializeBinaryField(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + + // Generate definition for one enum. + void GenerateEnum(const GeneratorOptions& options, io::Printer* printer, + const EnumDescriptor* enumdesc) const; + + // Generate an extension definition. + void GenerateExtension(const GeneratorOptions& options, io::Printer* printer, + const FieldDescriptor* field) const; + + // Generate addFoo() method for repeated primitive fields. + void GenerateRepeatedPrimitiveHelperMethods(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field, + bool untyped) const; + + // Generate addFoo() method for repeated message fields. + void GenerateRepeatedMessageHelperMethods(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Generator); +}; + +} // namespace js +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_JS_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/well_known_types_embed.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/well_known_types_embed.h new file mode 100644 index 0000000000000000000000000000000000000000..5e3d8361ab42ed19536e2d1e284fc9c0a0526122 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/well_known_types_embed.h @@ -0,0 +1,48 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_COMPILER_JS_WELL_KNOWN_TYPES_EMBED_H__ +#define GOOGLE_PROTOBUF_COMPILER_JS_WELL_KNOWN_TYPES_EMBED_H__ + +#include + +struct FileToc { + const char* name; + const char* data; +}; + +extern struct FileToc well_known_types_js[]; + +#endif // GOOGLE_PROTOBUF_COMPILER_JS_WELL_KNOWN_TYPES_EMBED_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..e5673a7bdaa2950a9e9b3a14eebbf5313c9696ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_generator.h @@ -0,0 +1,87 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates ObjectiveC code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_GENERATOR_H__ + +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace objectivec { + +// CodeGenerator implementation which generates a ObjectiveC source file and +// header. If you create your own protocol compiler binary and you want it to +// support ObjectiveC output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT ObjectiveCGenerator : public CodeGenerator { + public: + ObjectiveCGenerator(); + ~ObjectiveCGenerator(); + + ObjectiveCGenerator(const ObjectiveCGenerator&) = delete; + ObjectiveCGenerator& operator=(const ObjectiveCGenerator&) = delete; + + // implements CodeGenerator ---------------------------------------- + bool HasGenerateAll() const override; + bool Generate(const FileDescriptor* file, + const string& parameter, + GeneratorContext* context, + string* error) const override; + bool GenerateAll(const std::vector& files, + const string& parameter, + GeneratorContext* context, + string* error) const override; + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } +}; + +} // namespace objectivec +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_helpers.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..f170d077ab39a25c3ca56d9337c249f16cb57445 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_helpers.h @@ -0,0 +1,333 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Helper functions for generating ObjectiveC code. + +#ifndef GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_HELPERS_H__ +#define GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_HELPERS_H__ + +#include +#include + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace objectivec { + +// Generator options (see objectivec_generator.cc for a description of each): +struct Options { + Options(); + string expected_prefixes_path; + std::vector expected_prefixes_suppressions; + string generate_for_named_framework; + string named_framework_to_proto_path_mappings_path; + string runtime_import_prefix; +}; + +// Escape C++ trigraphs by escaping question marks to "\?". +string PROTOC_EXPORT EscapeTrigraphs(const string& to_escape); + +// Strips ".proto" or ".protodevel" from the end of a filename. +string PROTOC_EXPORT StripProto(const string& filename); + +// Remove white space from either end of a StringPiece. +void PROTOC_EXPORT TrimWhitespace(StringPiece* input); + +// Returns true if the name requires a ns_returns_not_retained attribute applied +// to it. +bool PROTOC_EXPORT IsRetainedName(const string& name); + +// Returns true if the name starts with "init" and will need to have special +// handling under ARC. +bool PROTOC_EXPORT IsInitName(const string& name); + +// Gets the objc_class_prefix. +string PROTOC_EXPORT FileClassPrefix(const FileDescriptor* file); + +// Gets the path of the file we're going to generate (sans the .pb.h +// extension). The path will be dependent on the objectivec package +// declared in the proto package. +string PROTOC_EXPORT FilePath(const FileDescriptor* file); + +// Just like FilePath(), but without the directory part. +string PROTOC_EXPORT FilePathBasename(const FileDescriptor* file); + +// Gets the name of the root class we'll generate in the file. This class +// is not meant for external consumption, but instead contains helpers that +// the rest of the classes need +string PROTOC_EXPORT FileClassName(const FileDescriptor* file); + +// These return the fully-qualified class name corresponding to the given +// descriptor. +string PROTOC_EXPORT ClassName(const Descriptor* descriptor); +string PROTOC_EXPORT ClassName(const Descriptor* descriptor, + string* out_suffix_added); +string PROTOC_EXPORT EnumName(const EnumDescriptor* descriptor); + +// Returns the fully-qualified name of the enum value corresponding to the +// the descriptor. +string PROTOC_EXPORT EnumValueName(const EnumValueDescriptor* descriptor); + +// Returns the name of the enum value corresponding to the descriptor. +string PROTOC_EXPORT EnumValueShortName(const EnumValueDescriptor* descriptor); + +// Reverse what an enum does. +string PROTOC_EXPORT UnCamelCaseEnumShortName(const string& name); + +// Returns the name to use for the extension (used as the method off the file's +// Root class). +string PROTOC_EXPORT ExtensionMethodName(const FieldDescriptor* descriptor); + +// Returns the transformed field name. +string PROTOC_EXPORT FieldName(const FieldDescriptor* field); +string PROTOC_EXPORT FieldNameCapitalized(const FieldDescriptor* field); + +// Returns the transformed oneof name. +string PROTOC_EXPORT OneofEnumName(const OneofDescriptor* descriptor); +string PROTOC_EXPORT OneofName(const OneofDescriptor* descriptor); +string PROTOC_EXPORT OneofNameCapitalized(const OneofDescriptor* descriptor); + +// Returns a symbol that can be used in C code to refer to an Objective C +// class without initializing the class. +string PROTOC_EXPORT ObjCClass(const string& class_name); + +// Declares an Objective C class without initializing the class so that it can +// be refrerred to by ObjCClass. +string PROTOC_EXPORT ObjCClassDeclaration(const string& class_name); + +inline bool HasPreservingUnknownEnumSemantics(const FileDescriptor* file) { + return file->syntax() == FileDescriptor::SYNTAX_PROTO3; +} + +inline bool IsMapEntryMessage(const Descriptor* descriptor) { + return descriptor->options().map_entry(); +} + +// Reverse of the above. +string PROTOC_EXPORT UnCamelCaseFieldName(const string& name, + const FieldDescriptor* field); + +enum ObjectiveCType { + OBJECTIVECTYPE_INT32, + OBJECTIVECTYPE_UINT32, + OBJECTIVECTYPE_INT64, + OBJECTIVECTYPE_UINT64, + OBJECTIVECTYPE_FLOAT, + OBJECTIVECTYPE_DOUBLE, + OBJECTIVECTYPE_BOOLEAN, + OBJECTIVECTYPE_STRING, + OBJECTIVECTYPE_DATA, + OBJECTIVECTYPE_ENUM, + OBJECTIVECTYPE_MESSAGE +}; + +enum FlagType { + FLAGTYPE_DESCRIPTOR_INITIALIZATION, + FLAGTYPE_EXTENSION, + FLAGTYPE_FIELD +}; + +template +string GetOptionalDeprecatedAttribute( + const TDescriptor* descriptor, + const FileDescriptor* file = NULL, + bool preSpace = true, bool postNewline = false) { + bool isDeprecated = descriptor->options().deprecated(); + // The file is only passed when checking Messages & Enums, so those types + // get tagged. At the moment, it doesn't seem to make sense to tag every + // field or enum value with when the file is deprecated. + bool isFileLevelDeprecation = false; + if (!isDeprecated && file) { + isFileLevelDeprecation = file->options().deprecated(); + isDeprecated = isFileLevelDeprecation; + } + if (isDeprecated) { + string message; + const FileDescriptor* sourceFile = descriptor->file(); + if (isFileLevelDeprecation) { + message = sourceFile->name() + " is deprecated."; + } else { + message = descriptor->full_name() + " is deprecated (see " + + sourceFile->name() + ")."; + } + + string result = string("GPB_DEPRECATED_MSG(\"") + message + "\")"; + if (preSpace) { + result.insert(0, " "); + } + if (postNewline) { + result.append("\n"); + } + return result; + } else { + return ""; + } +} + +string PROTOC_EXPORT GetCapitalizedType(const FieldDescriptor* field); + +ObjectiveCType PROTOC_EXPORT +GetObjectiveCType(FieldDescriptor::Type field_type); + +inline ObjectiveCType GetObjectiveCType(const FieldDescriptor* field) { + return GetObjectiveCType(field->type()); +} + +bool PROTOC_EXPORT IsPrimitiveType(const FieldDescriptor* field); +bool PROTOC_EXPORT IsReferenceType(const FieldDescriptor* field); + +string PROTOC_EXPORT GPBGenericValueFieldName(const FieldDescriptor* field); +string PROTOC_EXPORT DefaultValue(const FieldDescriptor* field); +bool PROTOC_EXPORT HasNonZeroDefaultValue(const FieldDescriptor* field); + +string PROTOC_EXPORT BuildFlagsString(const FlagType type, + const std::vector& strings); + +// Builds HeaderDoc/appledoc style comments out of the comments in the .proto +// file. +string PROTOC_EXPORT BuildCommentsString(const SourceLocation& location, + bool prefer_single_line); + +// The name the commonly used by the library when built as a framework. +// This lines up to the name used in the CocoaPod. +extern PROTOC_EXPORT const char* const ProtobufLibraryFrameworkName; +// Returns the CPP symbol name to use as the gate for framework style imports +// for the given framework name to use. +string PROTOC_EXPORT +ProtobufFrameworkImportSymbol(const string& framework_name); + +// Checks if the file is one of the proto's bundled with the library. +bool PROTOC_EXPORT +IsProtobufLibraryBundledProtoFile(const FileDescriptor* file); + +// Checks the prefix for the given files and outputs any warnings as needed. If +// there are flat out errors, then out_error is filled in with the first error +// and the result is false. +bool PROTOC_EXPORT +ValidateObjCClassPrefixes(const std::vector& files, + const Options& generation_options, string* out_error); + +// Generate decode data needed for ObjC's GPBDecodeTextFormatName() to transform +// the input into the expected output. +class PROTOC_EXPORT TextFormatDecodeData { + public: + TextFormatDecodeData(); + ~TextFormatDecodeData(); + + TextFormatDecodeData(const TextFormatDecodeData&) = delete; + TextFormatDecodeData& operator=(const TextFormatDecodeData&) = delete; + + void AddString(int32 key, const string& input_for_decode, + const string& desired_output); + size_t num_entries() const { return entries_.size(); } + string Data() const; + + static string DecodeDataForString(const string& input_for_decode, + const string& desired_output); + + private: + typedef std::pair DataEntry; + std::vector entries_; +}; + +// Helper for parsing simple files. +class PROTOC_EXPORT LineConsumer { + public: + LineConsumer(); + virtual ~LineConsumer(); + virtual bool ConsumeLine(const StringPiece& line, string* out_error) = 0; +}; + +bool PROTOC_EXPORT ParseSimpleFile(const string& path, + LineConsumer* line_consumer, + string* out_error); + +// Helper class for parsing framework import mappings and generating +// import statements. +class PROTOC_EXPORT ImportWriter { + public: + ImportWriter(const string& generate_for_named_framework, + const string& named_framework_to_proto_path_mappings_path, + const string& runtime_import_prefix, + bool include_wkt_imports); + ~ImportWriter(); + + void AddFile(const FileDescriptor* file, const string& header_extension); + void Print(io::Printer *printer) const; + + static void PrintRuntimeImports(io::Printer *printer, + const std::vector& header_to_import, + const string& runtime_import_prefix, + bool default_cpp_symbol = false); + + private: + class ProtoFrameworkCollector : public LineConsumer { + public: + ProtoFrameworkCollector(std::map* inout_proto_file_to_framework_name) + : map_(inout_proto_file_to_framework_name) {} + + virtual bool ConsumeLine(const StringPiece& line, string* out_error); + + private: + std::map* map_; + }; + + void ParseFrameworkMappings(); + + const string generate_for_named_framework_; + const string named_framework_to_proto_path_mappings_path_; + const string runtime_import_prefix_; + const bool include_wkt_imports_; + std::map proto_file_to_framework_name_; + bool need_to_parse_mapping_file_; + + std::vector protobuf_imports_; + std::vector other_framework_imports_; + std::vector other_imports_; +}; + +} // namespace objectivec +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_HELPERS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/parser.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/parser.h new file mode 100644 index 0000000000000000000000000000000000000000..ea3b64dc72f5316e51261584c153274455c033f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/parser.h @@ -0,0 +1,605 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Implements parsing of .proto files to FileDescriptorProtos. + +#ifndef GOOGLE_PROTOBUF_COMPILER_PARSER_H__ +#define GOOGLE_PROTOBUF_COMPILER_PARSER_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Message; + +namespace compiler { + +// Defined in this file. +class Parser; +class SourceLocationTable; + +// Implements parsing of protocol definitions (such as .proto files). +// +// Note that most users will be more interested in the Importer class. +// Parser is a lower-level class which simply converts a single .proto file +// to a FileDescriptorProto. It does not resolve import directives or perform +// many other kinds of validation needed to construct a complete +// FileDescriptor. +class PROTOBUF_EXPORT Parser { + public: + Parser(); + ~Parser(); + + // Parse the entire input and construct a FileDescriptorProto representing + // it. Returns true if no errors occurred, false otherwise. + bool Parse(io::Tokenizer* input, FileDescriptorProto* file); + + // Optional features: + + // DEPRECATED: New code should use the SourceCodeInfo embedded in the + // FileDescriptorProto. + // + // Requests that locations of certain definitions be recorded to the given + // SourceLocationTable while parsing. This can be used to look up exact line + // and column numbers for errors reported by DescriptorPool during validation. + // Set to NULL (the default) to discard source location information. + void RecordSourceLocationsTo(SourceLocationTable* location_table) { + source_location_table_ = location_table; + } + + // Requests that errors be recorded to the given ErrorCollector while + // parsing. Set to NULL (the default) to discard error messages. + void RecordErrorsTo(io::ErrorCollector* error_collector) { + error_collector_ = error_collector; + } + + // Returns the identifier used in the "syntax = " declaration, if one was + // seen during the last call to Parse(), or the empty string otherwise. + const std::string& GetSyntaxIdentifier() { return syntax_identifier_; } + + // If set true, input files will be required to begin with a syntax + // identifier. Otherwise, files may omit this. If a syntax identifier + // is provided, it must be 'syntax = "proto2";' and must appear at the + // top of this file regardless of whether or not it was required. + void SetRequireSyntaxIdentifier(bool value) { + require_syntax_identifier_ = value; + } + + // Call SetStopAfterSyntaxIdentifier(true) to tell the parser to stop + // parsing as soon as it has seen the syntax identifier, or lack thereof. + // This is useful for quickly identifying the syntax of the file without + // parsing the whole thing. If this is enabled, no error will be recorded + // if the syntax identifier is something other than "proto2" (since + // presumably the caller intends to deal with that), but other kinds of + // errors (e.g. parse errors) will still be reported. When this is enabled, + // you may pass a NULL FileDescriptorProto to Parse(). + void SetStopAfterSyntaxIdentifier(bool value) { + stop_after_syntax_identifier_ = value; + } + + private: + class LocationRecorder; + + // ================================================================= + // Error recovery helpers + + // Consume the rest of the current statement. This consumes tokens + // until it sees one of: + // ';' Consumes the token and returns. + // '{' Consumes the brace then calls SkipRestOfBlock(). + // '}' Returns without consuming. + // EOF Returns (can't consume). + // The Parser often calls SkipStatement() after encountering a syntax + // error. This allows it to go on parsing the following lines, allowing + // it to report more than just one error in the file. + void SkipStatement(); + + // Consume the rest of the current block, including nested blocks, + // ending after the closing '}' is encountered and consumed, or at EOF. + void SkipRestOfBlock(); + + // ----------------------------------------------------------------- + // Single-token consuming helpers + // + // These make parsing code more readable. + + // True if the current token is TYPE_END. + inline bool AtEnd(); + + // True if the next token matches the given text. + inline bool LookingAt(const char* text); + // True if the next token is of the given type. + inline bool LookingAtType(io::Tokenizer::TokenType token_type); + + // If the next token exactly matches the text given, consume it and return + // true. Otherwise, return false without logging an error. + bool TryConsume(const char* text); + + // These attempt to read some kind of token from the input. If successful, + // they return true. Otherwise they return false and add the given error + // to the error list. + + // Consume a token with the exact text given. + bool Consume(const char* text, const char* error); + // Same as above, but automatically generates the error "Expected \"text\".", + // where "text" is the expected token text. + bool Consume(const char* text); + // Consume a token of type IDENTIFIER and store its text in "output". + bool ConsumeIdentifier(std::string* output, const char* error); + // Consume an integer and store its value in "output". + bool ConsumeInteger(int* output, const char* error); + // Consume a signed integer and store its value in "output". + bool ConsumeSignedInteger(int* output, const char* error); + // Consume a 64-bit integer and store its value in "output". If the value + // is greater than max_value, an error will be reported. + bool ConsumeInteger64(uint64 max_value, uint64* output, const char* error); + // Consume a number and store its value in "output". This will accept + // tokens of either INTEGER or FLOAT type. + bool ConsumeNumber(double* output, const char* error); + // Consume a string literal and store its (unescaped) value in "output". + bool ConsumeString(std::string* output, const char* error); + + // Consume a token representing the end of the statement. Comments between + // this token and the next will be harvested for documentation. The given + // LocationRecorder should refer to the declaration that was just parsed; + // it will be populated with these comments. + // + // TODO(kenton): The LocationRecorder is const because historically locations + // have been passed around by const reference, for no particularly good + // reason. We should probably go through and change them all to mutable + // pointer to make this more intuitive. + bool TryConsumeEndOfDeclaration(const char* text, + const LocationRecorder* location); + bool TryConsumeEndOfDeclarationFinishScope(const char* text, + const LocationRecorder* location); + + bool ConsumeEndOfDeclaration(const char* text, + const LocationRecorder* location); + + // ----------------------------------------------------------------- + // Error logging helpers + + // Invokes error_collector_->AddError(), if error_collector_ is not NULL. + void AddError(int line, int column, const std::string& error); + + // Invokes error_collector_->AddError() with the line and column number + // of the current token. + void AddError(const std::string& error); + + // Invokes error_collector_->AddWarning() with the line and column number + // of the current token. + void AddWarning(const std::string& warning); + + // Records a location in the SourceCodeInfo.location table (see + // descriptor.proto). We use RAII to ensure that the start and end locations + // are recorded -- the constructor records the start location and the + // destructor records the end location. Since the parser is + // recursive-descent, this works out beautifully. + class PROTOBUF_EXPORT LocationRecorder { + public: + // Construct the file's "root" location. + LocationRecorder(Parser* parser); + + // Construct a location that represents a declaration nested within the + // given parent. E.g. a field's location is nested within the location + // for a message type. The parent's path will be copied, so you should + // call AddPath() only to add the path components leading from the parent + // to the child (as opposed to leading from the root to the child). + LocationRecorder(const LocationRecorder& parent); + + // Convenience constructors that call AddPath() one or two times. + LocationRecorder(const LocationRecorder& parent, int path1); + LocationRecorder(const LocationRecorder& parent, int path1, int path2); + + // Creates a recorder that generates locations into given source code info. + LocationRecorder(const LocationRecorder& parent, int path1, + SourceCodeInfo* source_code_info); + + ~LocationRecorder(); + + // Add a path component. See SourceCodeInfo.Location.path in + // descriptor.proto. + void AddPath(int path_component); + + // By default the location is considered to start at the current token at + // the time the LocationRecorder is created. StartAt() sets the start + // location to the given token instead. + void StartAt(const io::Tokenizer::Token& token); + + // Start at the same location as some other LocationRecorder. + void StartAt(const LocationRecorder& other); + + // By default the location is considered to end at the previous token at + // the time the LocationRecorder is destroyed. EndAt() sets the end + // location to the given token instead. + void EndAt(const io::Tokenizer::Token& token); + + // Records the start point of this location to the SourceLocationTable that + // was passed to RecordSourceLocationsTo(), if any. SourceLocationTable + // is an older way of keeping track of source locations which is still + // used in some places. + void RecordLegacyLocation( + const Message* descriptor, + DescriptorPool::ErrorCollector::ErrorLocation location); + void RecordLegacyImportLocation(const Message* descriptor, + const std::string& name); + + // Returns the number of path components in the recorder's current location. + int CurrentPathSize() const; + + // Attaches leading and trailing comments to the location. The two strings + // will be swapped into place, so after this is called *leading and + // *trailing will be empty. + // + // TODO(kenton): See comment on TryConsumeEndOfDeclaration(), above, for + // why this is const. + void AttachComments(std::string* leading, std::string* trailing, + std::vector* detached_comments) const; + + private: + // Indexes of parent and current location in the parent + // SourceCodeInfo.location repeated field. For top-level elements, + // parent_index_ is -1. + Parser* parser_; + SourceCodeInfo* source_code_info_; + SourceCodeInfo::Location* location_; + + void Init(const LocationRecorder& parent, SourceCodeInfo* source_code_info); + }; + + // ================================================================= + // Parsers for various language constructs + + // Parses the "syntax = \"proto2\";" line at the top of the file. Returns + // false if it failed to parse or if the syntax identifier was not + // recognized. + bool ParseSyntaxIdentifier(const LocationRecorder& parent); + + // These methods parse various individual bits of code. They return + // false if they completely fail to parse the construct. In this case, + // it is probably necessary to skip the rest of the statement to recover. + // However, if these methods return true, it does NOT mean that there + // were no errors; only that there were no *syntax* errors. For instance, + // if a service method is defined using proper syntax but uses a primitive + // type as its input or output, ParseMethodField() still returns true + // and only reports the error by calling AddError(). In practice, this + // makes logic much simpler for the caller. + + // Parse a top-level message, enum, service, etc. + bool ParseTopLevelStatement(FileDescriptorProto* file, + const LocationRecorder& root_location); + + // Parse various language high-level language construrcts. + bool ParseMessageDefinition(DescriptorProto* message, + const LocationRecorder& message_location, + const FileDescriptorProto* containing_file); + bool ParseEnumDefinition(EnumDescriptorProto* enum_type, + const LocationRecorder& enum_location, + const FileDescriptorProto* containing_file); + bool ParseServiceDefinition(ServiceDescriptorProto* service, + const LocationRecorder& service_location, + const FileDescriptorProto* containing_file); + bool ParsePackage(FileDescriptorProto* file, + const LocationRecorder& root_location, + const FileDescriptorProto* containing_file); + bool ParseImport(RepeatedPtrField* dependency, + RepeatedField* public_dependency, + RepeatedField* weak_dependency, + const LocationRecorder& root_location, + const FileDescriptorProto* containing_file); + + // These methods parse the contents of a message, enum, or service type and + // add them to the given object. They consume the entire block including + // the beginning and ending brace. + bool ParseMessageBlock(DescriptorProto* message, + const LocationRecorder& message_location, + const FileDescriptorProto* containing_file); + bool ParseEnumBlock(EnumDescriptorProto* enum_type, + const LocationRecorder& enum_location, + const FileDescriptorProto* containing_file); + bool ParseServiceBlock(ServiceDescriptorProto* service, + const LocationRecorder& service_location, + const FileDescriptorProto* containing_file); + + // Parse one statement within a message, enum, or service block, including + // final semicolon. + bool ParseMessageStatement(DescriptorProto* message, + const LocationRecorder& message_location, + const FileDescriptorProto* containing_file); + bool ParseEnumStatement(EnumDescriptorProto* message, + const LocationRecorder& enum_location, + const FileDescriptorProto* containing_file); + bool ParseServiceStatement(ServiceDescriptorProto* message, + const LocationRecorder& service_location, + const FileDescriptorProto* containing_file); + + // Parse a field of a message. If the field is a group, its type will be + // added to "messages". + // + // parent_location and location_field_number_for_nested_type are needed when + // parsing groups -- we need to generate a nested message type within the + // parent and record its location accordingly. Since the parent could be + // either a FileDescriptorProto or a DescriptorProto, we must pass in the + // correct field number to use. + bool ParseMessageField(FieldDescriptorProto* field, + RepeatedPtrField* messages, + const LocationRecorder& parent_location, + int location_field_number_for_nested_type, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Like ParseMessageField() but expects the label has already been filled in + // by the caller. + bool ParseMessageFieldNoLabel(FieldDescriptorProto* field, + RepeatedPtrField* messages, + const LocationRecorder& parent_location, + int location_field_number_for_nested_type, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Parse an "extensions" declaration. + bool ParseExtensions(DescriptorProto* message, + const LocationRecorder& extensions_location, + const FileDescriptorProto* containing_file); + + // Parse a "reserved" declaration. + bool ParseReserved(DescriptorProto* message, + const LocationRecorder& message_location); + bool ParseReservedNames(DescriptorProto* message, + const LocationRecorder& parent_location); + bool ParseReservedNumbers(DescriptorProto* message, + const LocationRecorder& parent_location); + bool ParseReserved(EnumDescriptorProto* message, + const LocationRecorder& message_location); + bool ParseReservedNames(EnumDescriptorProto* message, + const LocationRecorder& parent_location); + bool ParseReservedNumbers(EnumDescriptorProto* message, + const LocationRecorder& parent_location); + + // Parse an "extend" declaration. (See also comments for + // ParseMessageField().) + bool ParseExtend(RepeatedPtrField* extensions, + RepeatedPtrField* messages, + const LocationRecorder& parent_location, + int location_field_number_for_nested_type, + const LocationRecorder& extend_location, + const FileDescriptorProto* containing_file); + + // Parse a "oneof" declaration. The caller is responsible for setting + // oneof_decl->label() since it will have had to parse the label before it + // knew it was parsing a oneof. + bool ParseOneof(OneofDescriptorProto* oneof_decl, + DescriptorProto* containing_type, int oneof_index, + const LocationRecorder& oneof_location, + const LocationRecorder& containing_type_location, + const FileDescriptorProto* containing_file); + + // Parse a single enum value within an enum block. + bool ParseEnumConstant(EnumValueDescriptorProto* enum_value, + const LocationRecorder& enum_value_location, + const FileDescriptorProto* containing_file); + + // Parse enum constant options, i.e. the list in square brackets at the end + // of the enum constant value definition. + bool ParseEnumConstantOptions(EnumValueDescriptorProto* value, + const LocationRecorder& enum_value_location, + const FileDescriptorProto* containing_file); + + // Parse a single method within a service definition. + bool ParseServiceMethod(MethodDescriptorProto* method, + const LocationRecorder& method_location, + const FileDescriptorProto* containing_file); + + + // Parse options of a single method or stream. + bool ParseMethodOptions(const LocationRecorder& parent_location, + const FileDescriptorProto* containing_file, + const int optionsFieldNumber, + Message* mutable_options); + + // Parse "required", "optional", or "repeated" and fill in "label" + // with the value. Returns true if such a label is consumed. + bool ParseLabel(FieldDescriptorProto::Label* label, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Parse a type name and fill in "type" (if it is a primitive) or + // "type_name" (if it is not) with the type parsed. + bool ParseType(FieldDescriptorProto::Type* type, std::string* type_name); + // Parse a user-defined type and fill in "type_name" with the name. + // If a primitive type is named, it is treated as an error. + bool ParseUserDefinedType(std::string* type_name); + + // Parses field options, i.e. the stuff in square brackets at the end + // of a field definition. Also parses default value. + bool ParseFieldOptions(FieldDescriptorProto* field, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Parse the "default" option. This needs special handling because its + // type is the field's type. + bool ParseDefaultAssignment(FieldDescriptorProto* field, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + bool ParseJsonName(FieldDescriptorProto* field, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + enum OptionStyle { + OPTION_ASSIGNMENT, // just "name = value" + OPTION_STATEMENT // "option name = value;" + }; + + // Parse a single option name/value pair, e.g. "ctype = CORD". The name + // identifies a field of the given Message, and the value of that field + // is set to the parsed value. + bool ParseOption(Message* options, const LocationRecorder& options_location, + const FileDescriptorProto* containing_file, + OptionStyle style); + + // Parses a single part of a multipart option name. A multipart name consists + // of names separated by dots. Each name is either an identifier or a series + // of identifiers separated by dots and enclosed in parentheses. E.g., + // "foo.(bar.baz).qux". + bool ParseOptionNamePart(UninterpretedOption* uninterpreted_option, + const LocationRecorder& part_location, + const FileDescriptorProto* containing_file); + + // Parses a string surrounded by balanced braces. Strips off the outer + // braces and stores the enclosed string in *value. + // E.g., + // { foo } *value gets 'foo' + // { foo { bar: box } } *value gets 'foo { bar: box }' + // {} *value gets '' + // + // REQUIRES: LookingAt("{") + // When finished successfully, we are looking at the first token past + // the ending brace. + bool ParseUninterpretedBlock(std::string* value); + + struct MapField { + // Whether the field is a map field. + bool is_map_field; + // The types of the key and value if they are primitive types. + FieldDescriptorProto::Type key_type; + FieldDescriptorProto::Type value_type; + // Or the type names string if the types are customized types. + std::string key_type_name; + std::string value_type_name; + + MapField() : is_map_field(false) {} + }; + // Desugar the map syntax to generate a nested map entry message. + void GenerateMapEntry(const MapField& map_field, FieldDescriptorProto* field, + RepeatedPtrField* messages); + + // Whether fields without label default to optional fields. + bool DefaultToOptionalFields() const { + return syntax_identifier_ == "proto3"; + } + + + bool ValidateEnum(const EnumDescriptorProto* proto); + + // ================================================================= + + io::Tokenizer* input_; + io::ErrorCollector* error_collector_; + SourceCodeInfo* source_code_info_; + SourceLocationTable* source_location_table_; // legacy + bool had_errors_; + bool require_syntax_identifier_; + bool stop_after_syntax_identifier_; + std::string syntax_identifier_; + + // Leading doc comments for the next declaration. These are not complete + // yet; use ConsumeEndOfDeclaration() to get the complete comments. + std::string upcoming_doc_comments_; + + // Detached comments are not connected to any syntax entities. Elements in + // this vector are paragraphs of comments separated by empty lines. The + // detached comments will be put into the leading_detached_comments field for + // the next element (See SourceCodeInfo.Location in descriptor.proto), when + // ConsumeEndOfDeclaration() is called. + std::vector upcoming_detached_comments_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Parser); +}; + +// A table mapping (descriptor, ErrorLocation) pairs -- as reported by +// DescriptorPool when validating descriptors -- to line and column numbers +// within the original source code. +// +// This is semi-obsolete: FileDescriptorProto.source_code_info now contains +// far more complete information about source locations. However, as of this +// writing you still need to use SourceLocationTable when integrating with +// DescriptorPool. +class PROTOBUF_EXPORT SourceLocationTable { + public: + SourceLocationTable(); + ~SourceLocationTable(); + + // Finds the precise location of the given error and fills in *line and + // *column with the line and column numbers. If not found, sets *line to + // -1 and *column to 0 (since line = -1 is used to mean "error has no exact + // location" in the ErrorCollector interface). Returns true if found, false + // otherwise. + bool Find(const Message* descriptor, + DescriptorPool::ErrorCollector::ErrorLocation location, int* line, + int* column) const; + bool FindImport(const Message* descriptor, const std::string& name, int* line, + int* column) const; + + // Adds a location to the table. + void Add(const Message* descriptor, + DescriptorPool::ErrorCollector::ErrorLocation location, int line, + int column); + void AddImport(const Message* descriptor, const std::string& name, int line, + int column); + + // Clears the contents of the table. + void Clear(); + + private: + typedef std::map< + std::pair, + std::pair > + LocationMap; + LocationMap location_map_; + std::map, std::pair > + import_location_map_; +}; + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PARSER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/php/php_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/php/php_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..e2610ec4dde2b539e6e021d55dd30e252bdc94ea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/php/php_generator.h @@ -0,0 +1,97 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_COMPILER_PHP_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_PHP_GENERATOR_H__ + +#include +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace php { + +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + virtual bool Generate( + const FileDescriptor* file, + const string& parameter, + GeneratorContext* generator_context, + string* error) const override; + + bool GenerateAll(const std::vector& files, + const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const override; + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + private: + bool Generate( + const FileDescriptor* file, + bool is_descriptor, + bool aggregate_metadata, + const std::set& aggregate_metadata_prefixes, + GeneratorContext* generator_context, + string* error) const; +}; + +// To skip reserved keywords in php, some generated classname are prefixed. +// Other code generators may need following API to figure out the actual +// classname. +PROTOC_EXPORT std::string GeneratedClassName(const Descriptor* desc); +PROTOC_EXPORT std::string GeneratedClassName(const EnumDescriptor* desc); +PROTOC_EXPORT std::string GeneratedClassName(const ServiceDescriptor* desc); + +inline bool IsWrapperType(const FieldDescriptor* descriptor) { + return descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE && + descriptor->message_type()->file()->name() == "google/protobuf/wrappers.proto"; +} + +} // namespace php +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PHP_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..a25079235a36c11ad7ca8ef8eeea78c172178eeb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.h @@ -0,0 +1,99 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// +// Front-end for protoc code generator plugins written in C++. +// +// To implement a protoc plugin in C++, simply write an implementation of +// CodeGenerator, then create a main() function like: +// int main(int argc, char* argv[]) { +// MyCodeGenerator generator; +// return google::protobuf::compiler::PluginMain(argc, argv, &generator); +// } +// You must link your plugin against libprotobuf and libprotoc. +// +// The core part of PluginMain is to invoke the given CodeGenerator on a +// CodeGeneratorRequest to generate a CodeGeneratorResponse. This part is +// abstracted out and made into function GenerateCode so that it can be reused, +// for example, to implement a variant of PluginMain that does some +// preprocessing on the input CodeGeneratorRequest before feeding the request +// to the given code generator. +// +// To get protoc to use the plugin, do one of the following: +// * Place the plugin binary somewhere in the PATH and give it the name +// "protoc-gen-NAME" (replacing "NAME" with the name of your plugin). If you +// then invoke protoc with the parameter --NAME_out=OUT_DIR (again, replace +// "NAME" with your plugin's name), protoc will invoke your plugin to generate +// the output, which will be placed in OUT_DIR. +// * Place the plugin binary anywhere, with any name, and pass the --plugin +// parameter to protoc to direct it to your plugin like so: +// protoc --plugin=protoc-gen-NAME=path/to/mybinary --NAME_out=OUT_DIR +// On Windows, make sure to include the .exe suffix: +// protoc --plugin=protoc-gen-NAME=path/to/mybinary.exe --NAME_out=OUT_DIR + +#ifndef GOOGLE_PROTOBUF_COMPILER_PLUGIN_H__ +#define GOOGLE_PROTOBUF_COMPILER_PLUGIN_H__ + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { + +class CodeGenerator; // code_generator.h +class CodeGeneratorRequest; +class CodeGeneratorResponse; + +// Implements main() for a protoc plugin exposing the given code generator. +PROTOC_EXPORT int PluginMain(int argc, char* argv[], + const CodeGenerator* generator); + +// Generates code using the given code generator. Returns true if the code +// generation is successful. If the code generation fails, error_msg may be +// populated to describe the failure cause. +bool GenerateCode(const CodeGeneratorRequest& request, + const CodeGenerator& generator, + CodeGeneratorResponse* response, std::string* error_msg); + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PLUGIN_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..81ba11cec535baed50d5b8f892ed5ba82e45a58f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.pb.h @@ -0,0 +1,1803 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/compiler/plugin.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fcompiler_2fplugin_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fcompiler_2fplugin_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fcompiler_2fplugin_2eproto PROTOC_EXPORT +#ifdef major +#undef major +#endif +#ifdef minor +#undef minor +#endif +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOC_EXPORT TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[4] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOC_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +PROTOBUF_NAMESPACE_OPEN +namespace compiler { +class CodeGeneratorRequest; +class CodeGeneratorRequestDefaultTypeInternal; +PROTOC_EXPORT extern CodeGeneratorRequestDefaultTypeInternal _CodeGeneratorRequest_default_instance_; +class CodeGeneratorResponse; +class CodeGeneratorResponseDefaultTypeInternal; +PROTOC_EXPORT extern CodeGeneratorResponseDefaultTypeInternal _CodeGeneratorResponse_default_instance_; +class CodeGeneratorResponse_File; +class CodeGeneratorResponse_FileDefaultTypeInternal; +PROTOC_EXPORT extern CodeGeneratorResponse_FileDefaultTypeInternal _CodeGeneratorResponse_File_default_instance_; +class Version; +class VersionDefaultTypeInternal; +PROTOC_EXPORT extern VersionDefaultTypeInternal _Version_default_instance_; +} // namespace compiler +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorRequest* Arena::CreateMaybeMessage(Arena*); +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse* Arena::CreateMaybeMessage(Arena*); +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* Arena::CreateMaybeMessage(Arena*); +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::Version* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +namespace compiler { + +enum CodeGeneratorResponse_Feature : int { + CodeGeneratorResponse_Feature_FEATURE_NONE = 0, + CodeGeneratorResponse_Feature_FEATURE_PROTO3_OPTIONAL = 1 +}; +PROTOC_EXPORT bool CodeGeneratorResponse_Feature_IsValid(int value); +constexpr CodeGeneratorResponse_Feature CodeGeneratorResponse_Feature_Feature_MIN = CodeGeneratorResponse_Feature_FEATURE_NONE; +constexpr CodeGeneratorResponse_Feature CodeGeneratorResponse_Feature_Feature_MAX = CodeGeneratorResponse_Feature_FEATURE_PROTO3_OPTIONAL; +constexpr int CodeGeneratorResponse_Feature_Feature_ARRAYSIZE = CodeGeneratorResponse_Feature_Feature_MAX + 1; + +PROTOC_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* CodeGeneratorResponse_Feature_descriptor(); +template +inline const std::string& CodeGeneratorResponse_Feature_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function CodeGeneratorResponse_Feature_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + CodeGeneratorResponse_Feature_descriptor(), enum_t_value); +} +inline bool CodeGeneratorResponse_Feature_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, CodeGeneratorResponse_Feature* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + CodeGeneratorResponse_Feature_descriptor(), name, value); +} +// =================================================================== + +class PROTOC_EXPORT Version PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.Version) */ { + public: + inline Version() : Version(nullptr) {} + virtual ~Version(); + + Version(const Version& from); + Version(Version&& from) noexcept + : Version() { + *this = ::std::move(from); + } + + inline Version& operator=(const Version& from) { + CopyFrom(from); + return *this; + } + inline Version& operator=(Version&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Version& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Version* internal_default_instance() { + return reinterpret_cast( + &_Version_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Version& a, Version& b) { + a.Swap(&b); + } + inline void Swap(Version* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Version* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Version* New() const final { + return CreateMaybeMessage(nullptr); + } + + Version* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Version& from); + void MergeFrom(const Version& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Version* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.Version"; + } + protected: + explicit Version(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kSuffixFieldNumber = 4, + kMajorFieldNumber = 1, + kMinorFieldNumber = 2, + kPatchFieldNumber = 3, + }; + // optional string suffix = 4; + bool has_suffix() const; + private: + bool _internal_has_suffix() const; + public: + void clear_suffix(); + const std::string& suffix() const; + void set_suffix(const std::string& value); + void set_suffix(std::string&& value); + void set_suffix(const char* value); + void set_suffix(const char* value, size_t size); + std::string* mutable_suffix(); + std::string* release_suffix(); + void set_allocated_suffix(std::string* suffix); + private: + const std::string& _internal_suffix() const; + void _internal_set_suffix(const std::string& value); + std::string* _internal_mutable_suffix(); + public: + + // optional int32 major = 1; + bool has_major() const; + private: + bool _internal_has_major() const; + public: + void clear_major(); + ::PROTOBUF_NAMESPACE_ID::int32 major() const; + void set_major(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_major() const; + void _internal_set_major(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 minor = 2; + bool has_minor() const; + private: + bool _internal_has_minor() const; + public: + void clear_minor(); + ::PROTOBUF_NAMESPACE_ID::int32 minor() const; + void set_minor(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_minor() const; + void _internal_set_minor(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 patch = 3; + bool has_patch() const; + private: + bool _internal_has_patch() const; + public: + void clear_patch(); + ::PROTOBUF_NAMESPACE_ID::int32 patch() const; + void set_patch(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_patch() const; + void _internal_set_patch(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.Version) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr suffix_; + ::PROTOBUF_NAMESPACE_ID::int32 major_; + ::PROTOBUF_NAMESPACE_ID::int32 minor_; + ::PROTOBUF_NAMESPACE_ID::int32 patch_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOC_EXPORT CodeGeneratorRequest PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.CodeGeneratorRequest) */ { + public: + inline CodeGeneratorRequest() : CodeGeneratorRequest(nullptr) {} + virtual ~CodeGeneratorRequest(); + + CodeGeneratorRequest(const CodeGeneratorRequest& from); + CodeGeneratorRequest(CodeGeneratorRequest&& from) noexcept + : CodeGeneratorRequest() { + *this = ::std::move(from); + } + + inline CodeGeneratorRequest& operator=(const CodeGeneratorRequest& from) { + CopyFrom(from); + return *this; + } + inline CodeGeneratorRequest& operator=(CodeGeneratorRequest&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const CodeGeneratorRequest& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const CodeGeneratorRequest* internal_default_instance() { + return reinterpret_cast( + &_CodeGeneratorRequest_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(CodeGeneratorRequest& a, CodeGeneratorRequest& b) { + a.Swap(&b); + } + inline void Swap(CodeGeneratorRequest* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(CodeGeneratorRequest* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline CodeGeneratorRequest* New() const final { + return CreateMaybeMessage(nullptr); + } + + CodeGeneratorRequest* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const CodeGeneratorRequest& from); + void MergeFrom(const CodeGeneratorRequest& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(CodeGeneratorRequest* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.CodeGeneratorRequest"; + } + protected: + explicit CodeGeneratorRequest(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kFileToGenerateFieldNumber = 1, + kProtoFileFieldNumber = 15, + kParameterFieldNumber = 2, + kCompilerVersionFieldNumber = 3, + }; + // repeated string file_to_generate = 1; + int file_to_generate_size() const; + private: + int _internal_file_to_generate_size() const; + public: + void clear_file_to_generate(); + const std::string& file_to_generate(int index) const; + std::string* mutable_file_to_generate(int index); + void set_file_to_generate(int index, const std::string& value); + void set_file_to_generate(int index, std::string&& value); + void set_file_to_generate(int index, const char* value); + void set_file_to_generate(int index, const char* value, size_t size); + std::string* add_file_to_generate(); + void add_file_to_generate(const std::string& value); + void add_file_to_generate(std::string&& value); + void add_file_to_generate(const char* value); + void add_file_to_generate(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& file_to_generate() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_file_to_generate(); + private: + const std::string& _internal_file_to_generate(int index) const; + std::string* _internal_add_file_to_generate(); + public: + + // repeated .google.protobuf.FileDescriptorProto proto_file = 15; + int proto_file_size() const; + private: + int _internal_proto_file_size() const; + public: + void clear_proto_file(); + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* mutable_proto_file(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >* + mutable_proto_file(); + private: + const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& _internal_proto_file(int index) const; + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* _internal_add_proto_file(); + public: + const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& proto_file(int index) const; + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* add_proto_file(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >& + proto_file() const; + + // optional string parameter = 2; + bool has_parameter() const; + private: + bool _internal_has_parameter() const; + public: + void clear_parameter(); + const std::string& parameter() const; + void set_parameter(const std::string& value); + void set_parameter(std::string&& value); + void set_parameter(const char* value); + void set_parameter(const char* value, size_t size); + std::string* mutable_parameter(); + std::string* release_parameter(); + void set_allocated_parameter(std::string* parameter); + private: + const std::string& _internal_parameter() const; + void _internal_set_parameter(const std::string& value); + std::string* _internal_mutable_parameter(); + public: + + // optional .google.protobuf.compiler.Version compiler_version = 3; + bool has_compiler_version() const; + private: + bool _internal_has_compiler_version() const; + public: + void clear_compiler_version(); + const PROTOBUF_NAMESPACE_ID::compiler::Version& compiler_version() const; + PROTOBUF_NAMESPACE_ID::compiler::Version* release_compiler_version(); + PROTOBUF_NAMESPACE_ID::compiler::Version* mutable_compiler_version(); + void set_allocated_compiler_version(PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version); + private: + const PROTOBUF_NAMESPACE_ID::compiler::Version& _internal_compiler_version() const; + PROTOBUF_NAMESPACE_ID::compiler::Version* _internal_mutable_compiler_version(); + public: + void unsafe_arena_set_allocated_compiler_version( + PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version); + PROTOBUF_NAMESPACE_ID::compiler::Version* unsafe_arena_release_compiler_version(); + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.CodeGeneratorRequest) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField file_to_generate_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto > proto_file_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr parameter_; + PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOC_EXPORT CodeGeneratorResponse_File PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.CodeGeneratorResponse.File) */ { + public: + inline CodeGeneratorResponse_File() : CodeGeneratorResponse_File(nullptr) {} + virtual ~CodeGeneratorResponse_File(); + + CodeGeneratorResponse_File(const CodeGeneratorResponse_File& from); + CodeGeneratorResponse_File(CodeGeneratorResponse_File&& from) noexcept + : CodeGeneratorResponse_File() { + *this = ::std::move(from); + } + + inline CodeGeneratorResponse_File& operator=(const CodeGeneratorResponse_File& from) { + CopyFrom(from); + return *this; + } + inline CodeGeneratorResponse_File& operator=(CodeGeneratorResponse_File&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const CodeGeneratorResponse_File& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const CodeGeneratorResponse_File* internal_default_instance() { + return reinterpret_cast( + &_CodeGeneratorResponse_File_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(CodeGeneratorResponse_File& a, CodeGeneratorResponse_File& b) { + a.Swap(&b); + } + inline void Swap(CodeGeneratorResponse_File* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(CodeGeneratorResponse_File* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline CodeGeneratorResponse_File* New() const final { + return CreateMaybeMessage(nullptr); + } + + CodeGeneratorResponse_File* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const CodeGeneratorResponse_File& from); + void MergeFrom(const CodeGeneratorResponse_File& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(CodeGeneratorResponse_File* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.CodeGeneratorResponse.File"; + } + protected: + explicit CodeGeneratorResponse_File(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kInsertionPointFieldNumber = 2, + kContentFieldNumber = 15, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string insertion_point = 2; + bool has_insertion_point() const; + private: + bool _internal_has_insertion_point() const; + public: + void clear_insertion_point(); + const std::string& insertion_point() const; + void set_insertion_point(const std::string& value); + void set_insertion_point(std::string&& value); + void set_insertion_point(const char* value); + void set_insertion_point(const char* value, size_t size); + std::string* mutable_insertion_point(); + std::string* release_insertion_point(); + void set_allocated_insertion_point(std::string* insertion_point); + private: + const std::string& _internal_insertion_point() const; + void _internal_set_insertion_point(const std::string& value); + std::string* _internal_mutable_insertion_point(); + public: + + // optional string content = 15; + bool has_content() const; + private: + bool _internal_has_content() const; + public: + void clear_content(); + const std::string& content() const; + void set_content(const std::string& value); + void set_content(std::string&& value); + void set_content(const char* value); + void set_content(const char* value, size_t size); + std::string* mutable_content(); + std::string* release_content(); + void set_allocated_content(std::string* content); + private: + const std::string& _internal_content() const; + void _internal_set_content(const std::string& value); + std::string* _internal_mutable_content(); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.CodeGeneratorResponse.File) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr insertion_point_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr content_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOC_EXPORT CodeGeneratorResponse PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.CodeGeneratorResponse) */ { + public: + inline CodeGeneratorResponse() : CodeGeneratorResponse(nullptr) {} + virtual ~CodeGeneratorResponse(); + + CodeGeneratorResponse(const CodeGeneratorResponse& from); + CodeGeneratorResponse(CodeGeneratorResponse&& from) noexcept + : CodeGeneratorResponse() { + *this = ::std::move(from); + } + + inline CodeGeneratorResponse& operator=(const CodeGeneratorResponse& from) { + CopyFrom(from); + return *this; + } + inline CodeGeneratorResponse& operator=(CodeGeneratorResponse&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const CodeGeneratorResponse& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const CodeGeneratorResponse* internal_default_instance() { + return reinterpret_cast( + &_CodeGeneratorResponse_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(CodeGeneratorResponse& a, CodeGeneratorResponse& b) { + a.Swap(&b); + } + inline void Swap(CodeGeneratorResponse* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(CodeGeneratorResponse* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline CodeGeneratorResponse* New() const final { + return CreateMaybeMessage(nullptr); + } + + CodeGeneratorResponse* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const CodeGeneratorResponse& from); + void MergeFrom(const CodeGeneratorResponse& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(CodeGeneratorResponse* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.CodeGeneratorResponse"; + } + protected: + explicit CodeGeneratorResponse(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef CodeGeneratorResponse_File File; + + typedef CodeGeneratorResponse_Feature Feature; + static constexpr Feature FEATURE_NONE = + CodeGeneratorResponse_Feature_FEATURE_NONE; + static constexpr Feature FEATURE_PROTO3_OPTIONAL = + CodeGeneratorResponse_Feature_FEATURE_PROTO3_OPTIONAL; + static inline bool Feature_IsValid(int value) { + return CodeGeneratorResponse_Feature_IsValid(value); + } + static constexpr Feature Feature_MIN = + CodeGeneratorResponse_Feature_Feature_MIN; + static constexpr Feature Feature_MAX = + CodeGeneratorResponse_Feature_Feature_MAX; + static constexpr int Feature_ARRAYSIZE = + CodeGeneratorResponse_Feature_Feature_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + Feature_descriptor() { + return CodeGeneratorResponse_Feature_descriptor(); + } + template + static inline const std::string& Feature_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Feature_Name."); + return CodeGeneratorResponse_Feature_Name(enum_t_value); + } + static inline bool Feature_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + Feature* value) { + return CodeGeneratorResponse_Feature_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kFileFieldNumber = 15, + kErrorFieldNumber = 1, + kSupportedFeaturesFieldNumber = 2, + }; + // repeated .google.protobuf.compiler.CodeGeneratorResponse.File file = 15; + int file_size() const; + private: + int _internal_file_size() const; + public: + void clear_file(); + PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* mutable_file(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >* + mutable_file(); + private: + const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& _internal_file(int index) const; + PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* _internal_add_file(); + public: + const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& file(int index) const; + PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* add_file(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >& + file() const; + + // optional string error = 1; + bool has_error() const; + private: + bool _internal_has_error() const; + public: + void clear_error(); + const std::string& error() const; + void set_error(const std::string& value); + void set_error(std::string&& value); + void set_error(const char* value); + void set_error(const char* value, size_t size); + std::string* mutable_error(); + std::string* release_error(); + void set_allocated_error(std::string* error); + private: + const std::string& _internal_error() const; + void _internal_set_error(const std::string& value); + std::string* _internal_mutable_error(); + public: + + // optional uint64 supported_features = 2; + bool has_supported_features() const; + private: + bool _internal_has_supported_features() const; + public: + void clear_supported_features(); + ::PROTOBUF_NAMESPACE_ID::uint64 supported_features() const; + void set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_supported_features() const; + void _internal_set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.CodeGeneratorResponse) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File > file_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr error_; + ::PROTOBUF_NAMESPACE_ID::uint64 supported_features_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Version + +// optional int32 major = 1; +inline bool Version::_internal_has_major() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Version::has_major() const { + return _internal_has_major(); +} +inline void Version::clear_major() { + major_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::_internal_major() const { + return major_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::major() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.major) + return _internal_major(); +} +inline void Version::_internal_set_major(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + major_ = value; +} +inline void Version::set_major(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_major(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.major) +} + +// optional int32 minor = 2; +inline bool Version::_internal_has_minor() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool Version::has_minor() const { + return _internal_has_minor(); +} +inline void Version::clear_minor() { + minor_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::_internal_minor() const { + return minor_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::minor() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.minor) + return _internal_minor(); +} +inline void Version::_internal_set_minor(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + minor_ = value; +} +inline void Version::set_minor(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_minor(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.minor) +} + +// optional int32 patch = 3; +inline bool Version::_internal_has_patch() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool Version::has_patch() const { + return _internal_has_patch(); +} +inline void Version::clear_patch() { + patch_ = 0; + _has_bits_[0] &= ~0x00000008u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::_internal_patch() const { + return patch_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::patch() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.patch) + return _internal_patch(); +} +inline void Version::_internal_set_patch(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000008u; + patch_ = value; +} +inline void Version::set_patch(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_patch(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.patch) +} + +// optional string suffix = 4; +inline bool Version::_internal_has_suffix() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Version::has_suffix() const { + return _internal_has_suffix(); +} +inline void Version::clear_suffix() { + suffix_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Version::suffix() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.suffix) + return _internal_suffix(); +} +inline void Version::set_suffix(const std::string& value) { + _internal_set_suffix(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.suffix) +} +inline std::string* Version::mutable_suffix() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.Version.suffix) + return _internal_mutable_suffix(); +} +inline const std::string& Version::_internal_suffix() const { + return suffix_.Get(); +} +inline void Version::_internal_set_suffix(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + suffix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Version::set_suffix(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + suffix_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.Version.suffix) +} +inline void Version::set_suffix(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + suffix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.Version.suffix) +} +inline void Version::set_suffix(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + suffix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.Version.suffix) +} +inline std::string* Version::_internal_mutable_suffix() { + _has_bits_[0] |= 0x00000001u; + return suffix_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Version::release_suffix() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.Version.suffix) + if (!_internal_has_suffix()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return suffix_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Version::set_allocated_suffix(std::string* suffix) { + if (suffix != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + suffix_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), suffix, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.Version.suffix) +} + +// ------------------------------------------------------------------- + +// CodeGeneratorRequest + +// repeated string file_to_generate = 1; +inline int CodeGeneratorRequest::_internal_file_to_generate_size() const { + return file_to_generate_.size(); +} +inline int CodeGeneratorRequest::file_to_generate_size() const { + return _internal_file_to_generate_size(); +} +inline void CodeGeneratorRequest::clear_file_to_generate() { + file_to_generate_.Clear(); +} +inline std::string* CodeGeneratorRequest::add_file_to_generate() { + // @@protoc_insertion_point(field_add_mutable:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return _internal_add_file_to_generate(); +} +inline const std::string& CodeGeneratorRequest::_internal_file_to_generate(int index) const { + return file_to_generate_.Get(index); +} +inline const std::string& CodeGeneratorRequest::file_to_generate(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return _internal_file_to_generate(index); +} +inline std::string* CodeGeneratorRequest::mutable_file_to_generate(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return file_to_generate_.Mutable(index); +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + file_to_generate_.Mutable(index)->assign(value); +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + file_to_generate_.Mutable(index)->assign(std::move(value)); +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + file_to_generate_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, const char* value, size_t size) { + file_to_generate_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline std::string* CodeGeneratorRequest::_internal_add_file_to_generate() { + return file_to_generate_.Add(); +} +inline void CodeGeneratorRequest::add_file_to_generate(const std::string& value) { + file_to_generate_.Add()->assign(value); + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::add_file_to_generate(std::string&& value) { + file_to_generate_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::add_file_to_generate(const char* value) { + GOOGLE_DCHECK(value != nullptr); + file_to_generate_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::add_file_to_generate(const char* value, size_t size) { + file_to_generate_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +CodeGeneratorRequest::file_to_generate() const { + // @@protoc_insertion_point(field_list:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return file_to_generate_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +CodeGeneratorRequest::mutable_file_to_generate() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return &file_to_generate_; +} + +// optional string parameter = 2; +inline bool CodeGeneratorRequest::_internal_has_parameter() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool CodeGeneratorRequest::has_parameter() const { + return _internal_has_parameter(); +} +inline void CodeGeneratorRequest::clear_parameter() { + parameter_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& CodeGeneratorRequest::parameter() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.parameter) + return _internal_parameter(); +} +inline void CodeGeneratorRequest::set_parameter(const std::string& value) { + _internal_set_parameter(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline std::string* CodeGeneratorRequest::mutable_parameter() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.parameter) + return _internal_mutable_parameter(); +} +inline const std::string& CodeGeneratorRequest::_internal_parameter() const { + return parameter_.Get(); +} +inline void CodeGeneratorRequest::_internal_set_parameter(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + parameter_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorRequest::set_parameter(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + parameter_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline void CodeGeneratorRequest::set_parameter(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + parameter_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline void CodeGeneratorRequest::set_parameter(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + parameter_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline std::string* CodeGeneratorRequest::_internal_mutable_parameter() { + _has_bits_[0] |= 0x00000001u; + return parameter_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorRequest::release_parameter() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorRequest.parameter) + if (!_internal_has_parameter()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return parameter_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorRequest::set_allocated_parameter(std::string* parameter) { + if (parameter != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + parameter_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), parameter, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} + +// repeated .google.protobuf.FileDescriptorProto proto_file = 15; +inline int CodeGeneratorRequest::_internal_proto_file_size() const { + return proto_file_.size(); +} +inline int CodeGeneratorRequest::proto_file_size() const { + return _internal_proto_file_size(); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* CodeGeneratorRequest::mutable_proto_file(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return proto_file_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >* +CodeGeneratorRequest::mutable_proto_file() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return &proto_file_; +} +inline const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& CodeGeneratorRequest::_internal_proto_file(int index) const { + return proto_file_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& CodeGeneratorRequest::proto_file(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return _internal_proto_file(index); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* CodeGeneratorRequest::_internal_add_proto_file() { + return proto_file_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* CodeGeneratorRequest::add_proto_file() { + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return _internal_add_proto_file(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >& +CodeGeneratorRequest::proto_file() const { + // @@protoc_insertion_point(field_list:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return proto_file_; +} + +// optional .google.protobuf.compiler.Version compiler_version = 3; +inline bool CodeGeneratorRequest::_internal_has_compiler_version() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || compiler_version_ != nullptr); + return value; +} +inline bool CodeGeneratorRequest::has_compiler_version() const { + return _internal_has_compiler_version(); +} +inline void CodeGeneratorRequest::clear_compiler_version() { + if (compiler_version_ != nullptr) compiler_version_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const PROTOBUF_NAMESPACE_ID::compiler::Version& CodeGeneratorRequest::_internal_compiler_version() const { + const PROTOBUF_NAMESPACE_ID::compiler::Version* p = compiler_version_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::compiler::_Version_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::compiler::Version& CodeGeneratorRequest::compiler_version() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) + return _internal_compiler_version(); +} +inline void CodeGeneratorRequest::unsafe_arena_set_allocated_compiler_version( + PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(compiler_version_); + } + compiler_version_ = compiler_version; + if (compiler_version) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::release_compiler_version() { + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::compiler::Version* temp = compiler_version_; + compiler_version_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::unsafe_arena_release_compiler_version() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::compiler::Version* temp = compiler_version_; + compiler_version_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::_internal_mutable_compiler_version() { + _has_bits_[0] |= 0x00000002u; + if (compiler_version_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + compiler_version_ = p; + } + return compiler_version_; +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::mutable_compiler_version() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) + return _internal_mutable_compiler_version(); +} +inline void CodeGeneratorRequest::set_allocated_compiler_version(PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete compiler_version_; + } + if (compiler_version) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(compiler_version); + if (message_arena != submessage_arena) { + compiler_version = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, compiler_version, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + compiler_version_ = compiler_version; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) +} + +// ------------------------------------------------------------------- + +// CodeGeneratorResponse_File + +// optional string name = 1; +inline bool CodeGeneratorResponse_File::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool CodeGeneratorResponse_File::has_name() const { + return _internal_has_name(); +} +inline void CodeGeneratorResponse_File::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& CodeGeneratorResponse_File::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.File.name) + return _internal_name(); +} +inline void CodeGeneratorResponse_File::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline std::string* CodeGeneratorResponse_File::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.File.name) + return _internal_mutable_name(); +} +inline const std::string& CodeGeneratorResponse_File::_internal_name() const { + return name_.Get(); +} +inline void CodeGeneratorResponse_File::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse_File::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline void CodeGeneratorResponse_File::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline void CodeGeneratorResponse_File::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline std::string* CodeGeneratorResponse_File::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse_File::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.File.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse_File::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} + +// optional string insertion_point = 2; +inline bool CodeGeneratorResponse_File::_internal_has_insertion_point() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool CodeGeneratorResponse_File::has_insertion_point() const { + return _internal_has_insertion_point(); +} +inline void CodeGeneratorResponse_File::clear_insertion_point() { + insertion_point_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& CodeGeneratorResponse_File::insertion_point() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) + return _internal_insertion_point(); +} +inline void CodeGeneratorResponse_File::set_insertion_point(const std::string& value) { + _internal_set_insertion_point(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline std::string* CodeGeneratorResponse_File::mutable_insertion_point() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) + return _internal_mutable_insertion_point(); +} +inline const std::string& CodeGeneratorResponse_File::_internal_insertion_point() const { + return insertion_point_.Get(); +} +inline void CodeGeneratorResponse_File::_internal_set_insertion_point(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse_File::set_insertion_point(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline void CodeGeneratorResponse_File::set_insertion_point(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline void CodeGeneratorResponse_File::set_insertion_point(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline std::string* CodeGeneratorResponse_File::_internal_mutable_insertion_point() { + _has_bits_[0] |= 0x00000002u; + return insertion_point_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse_File::release_insertion_point() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) + if (!_internal_has_insertion_point()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return insertion_point_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse_File::set_allocated_insertion_point(std::string* insertion_point) { + if (insertion_point != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + insertion_point_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), insertion_point, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} + +// optional string content = 15; +inline bool CodeGeneratorResponse_File::_internal_has_content() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool CodeGeneratorResponse_File::has_content() const { + return _internal_has_content(); +} +inline void CodeGeneratorResponse_File::clear_content() { + content_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& CodeGeneratorResponse_File::content() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.File.content) + return _internal_content(); +} +inline void CodeGeneratorResponse_File::set_content(const std::string& value) { + _internal_set_content(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline std::string* CodeGeneratorResponse_File::mutable_content() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.File.content) + return _internal_mutable_content(); +} +inline const std::string& CodeGeneratorResponse_File::_internal_content() const { + return content_.Get(); +} +inline void CodeGeneratorResponse_File::_internal_set_content(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse_File::set_content(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + content_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline void CodeGeneratorResponse_File::set_content(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline void CodeGeneratorResponse_File::set_content(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline std::string* CodeGeneratorResponse_File::_internal_mutable_content() { + _has_bits_[0] |= 0x00000004u; + return content_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse_File::release_content() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.File.content) + if (!_internal_has_content()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return content_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse_File::set_allocated_content(std::string* content) { + if (content != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + content_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), content, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} + +// ------------------------------------------------------------------- + +// CodeGeneratorResponse + +// optional string error = 1; +inline bool CodeGeneratorResponse::_internal_has_error() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool CodeGeneratorResponse::has_error() const { + return _internal_has_error(); +} +inline void CodeGeneratorResponse::clear_error() { + error_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& CodeGeneratorResponse::error() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.error) + return _internal_error(); +} +inline void CodeGeneratorResponse::set_error(const std::string& value) { + _internal_set_error(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline std::string* CodeGeneratorResponse::mutable_error() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.error) + return _internal_mutable_error(); +} +inline const std::string& CodeGeneratorResponse::_internal_error() const { + return error_.Get(); +} +inline void CodeGeneratorResponse::_internal_set_error(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + error_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse::set_error(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + error_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline void CodeGeneratorResponse::set_error(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + error_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline void CodeGeneratorResponse::set_error(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + error_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline std::string* CodeGeneratorResponse::_internal_mutable_error() { + _has_bits_[0] |= 0x00000001u; + return error_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse::release_error() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.error) + if (!_internal_has_error()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return error_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse::set_allocated_error(std::string* error) { + if (error != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + error_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), error, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.error) +} + +// optional uint64 supported_features = 2; +inline bool CodeGeneratorResponse::_internal_has_supported_features() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool CodeGeneratorResponse::has_supported_features() const { + return _internal_has_supported_features(); +} +inline void CodeGeneratorResponse::clear_supported_features() { + supported_features_ = PROTOBUF_ULONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 CodeGeneratorResponse::_internal_supported_features() const { + return supported_features_; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 CodeGeneratorResponse::supported_features() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.supported_features) + return _internal_supported_features(); +} +inline void CodeGeneratorResponse::_internal_set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _has_bits_[0] |= 0x00000002u; + supported_features_ = value; +} +inline void CodeGeneratorResponse::set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_set_supported_features(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.supported_features) +} + +// repeated .google.protobuf.compiler.CodeGeneratorResponse.File file = 15; +inline int CodeGeneratorResponse::_internal_file_size() const { + return file_.size(); +} +inline int CodeGeneratorResponse::file_size() const { + return _internal_file_size(); +} +inline void CodeGeneratorResponse::clear_file() { + file_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* CodeGeneratorResponse::mutable_file(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.file) + return file_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >* +CodeGeneratorResponse::mutable_file() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.compiler.CodeGeneratorResponse.file) + return &file_; +} +inline const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& CodeGeneratorResponse::_internal_file(int index) const { + return file_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& CodeGeneratorResponse::file(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.file) + return _internal_file(index); +} +inline PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* CodeGeneratorResponse::_internal_add_file() { + return file_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* CodeGeneratorResponse::add_file() { + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorResponse.file) + return _internal_add_file(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >& +CodeGeneratorResponse::file() const { + // @@protoc_insertion_point(field_list:google.protobuf.compiler.CodeGeneratorResponse.file) + return file_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace compiler +PROTOBUF_NAMESPACE_CLOSE + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_Feature> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_Feature>() { + return PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_Feature_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fcompiler_2fplugin_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/python/python_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/python/python_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..b3d3e7fd6c1dff0a191a3298b6a74d4c54cb3734 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/python/python_generator.h @@ -0,0 +1,187 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: robinson@google.com (Will Robinson) +// +// Generates Python code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__ + +#include + +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class EnumValueDescriptor; +class FieldDescriptor; +class OneofDescriptor; +class ServiceDescriptor; + +namespace io { +class Printer; +} + +namespace compiler { +namespace python { + +// CodeGenerator implementation for generated Python protocol buffer classes. +// If you create your own protocol compiler binary and you want it to support +// Python output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + Generator(); + virtual ~Generator(); + + // CodeGenerator methods. + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const override; + + uint64_t GetSupportedFeatures() const override; + + private: + void PrintImports() const; + void PrintFileDescriptor() const; + void PrintTopLevelEnums() const; + void PrintAllNestedEnumsInFile() const; + void PrintNestedEnums(const Descriptor& descriptor) const; + void PrintEnum(const EnumDescriptor& enum_descriptor) const; + + void PrintTopLevelExtensions() const; + + void PrintFieldDescriptor(const FieldDescriptor& field, + bool is_extension) const; + void PrintFieldDescriptorsInDescriptor( + const Descriptor& message_descriptor, bool is_extension, + const std::string& list_variable_name, int (Descriptor::*CountFn)() const, + const FieldDescriptor* (Descriptor::*GetterFn)(int)const) const; + void PrintFieldsInDescriptor(const Descriptor& message_descriptor) const; + void PrintExtensionsInDescriptor(const Descriptor& message_descriptor) const; + void PrintMessageDescriptors() const; + void PrintDescriptor(const Descriptor& message_descriptor) const; + void PrintNestedDescriptors(const Descriptor& containing_descriptor) const; + + void PrintMessages() const; + void PrintMessage(const Descriptor& message_descriptor, + const std::string& prefix, + std::vector* to_register, + bool is_nested) const; + void PrintNestedMessages(const Descriptor& containing_descriptor, + const std::string& prefix, + std::vector* to_register) const; + + void FixForeignFieldsInDescriptors() const; + void FixForeignFieldsInDescriptor( + const Descriptor& descriptor, + const Descriptor* containing_descriptor) const; + void FixForeignFieldsInField(const Descriptor* containing_type, + const FieldDescriptor& field, + const std::string& python_dict_name) const; + void AddMessageToFileDescriptor(const Descriptor& descriptor) const; + void AddEnumToFileDescriptor(const EnumDescriptor& descriptor) const; + void AddExtensionToFileDescriptor(const FieldDescriptor& descriptor) const; + void AddServiceToFileDescriptor(const ServiceDescriptor& descriptor) const; + std::string FieldReferencingExpression( + const Descriptor* containing_type, const FieldDescriptor& field, + const std::string& python_dict_name) const; + template + void FixContainingTypeInDescriptor( + const DescriptorT& descriptor, + const Descriptor* containing_descriptor) const; + + void FixForeignFieldsInExtensions() const; + void FixForeignFieldsInExtension( + const FieldDescriptor& extension_field) const; + void FixForeignFieldsInNestedExtensions(const Descriptor& descriptor) const; + + void PrintServices() const; + void PrintServiceDescriptors() const; + void PrintServiceDescriptor(const ServiceDescriptor& descriptor) const; + void PrintServiceClass(const ServiceDescriptor& descriptor) const; + void PrintServiceStub(const ServiceDescriptor& descriptor) const; + void PrintDescriptorKeyAndModuleName( + const ServiceDescriptor& descriptor) const; + + void PrintEnumValueDescriptor(const EnumValueDescriptor& descriptor) const; + std::string OptionsValue(const std::string& serialized_options) const; + bool GeneratingDescriptorProto() const; + + template + std::string ModuleLevelDescriptorName(const DescriptorT& descriptor) const; + std::string ModuleLevelMessageName(const Descriptor& descriptor) const; + std::string ModuleLevelServiceDescriptorName( + const ServiceDescriptor& descriptor) const; + + template + void PrintSerializedPbInterval(const DescriptorT& descriptor, + DescriptorProtoT& proto) const; + + void FixAllDescriptorOptions() const; + void FixOptionsForField(const FieldDescriptor& field) const; + void FixOptionsForOneof(const OneofDescriptor& oneof) const; + void FixOptionsForEnum(const EnumDescriptor& descriptor) const; + void FixOptionsForMessage(const Descriptor& descriptor) const; + + void CopyPublicDependenciesAliases(const std::string& copy_from, + const FileDescriptor* file) const; + + // Very coarse-grained lock to ensure that Generate() is reentrant. + // Guards file_, printer_ and file_descriptor_serialized_. + mutable Mutex mutex_; + mutable const FileDescriptor* file_; // Set in Generate(). Under mutex_. + mutable std::string file_descriptor_serialized_; + mutable io::Printer* printer_; // Set in Generate(). Under mutex_. + mutable bool pure_python_workable_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Generator); +}; + +} // namespace python +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/ruby/ruby_generator.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/ruby/ruby_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..9d297c5f183c8d7a041db0c5c074bf133bf9a222 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/ruby/ruby_generator.h @@ -0,0 +1,73 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates Ruby code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_RUBY_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_RUBY_GENERATOR_H__ + +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace ruby { + +// CodeGenerator implementation for generated Ruby protocol buffer classes. +// If you create your own protocol compiler binary and you want it to support +// Ruby output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + bool Generate(const FileDescriptor* file, const string& parameter, + GeneratorContext* generator_context, + string* error) const override; + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } +}; + +} // namespace ruby +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_RUBY_GENERATOR_H__ + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor.h new file mode 100644 index 0000000000000000000000000000000000000000..1865cafd595c20b21e6a05b8054291b9a7376a71 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor.h @@ -0,0 +1,2325 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains classes which describe a type of protocol message. +// You can use a message's descriptor to learn at runtime what fields +// it contains and what the types of those fields are. The Message +// interface also allows you to dynamically access and modify individual +// fields by passing the FieldDescriptor of the field you are interested +// in. +// +// Most users will not care about descriptors, because they will write +// code specific to certain protocol types and will simply use the classes +// generated by the protocol compiler directly. Advanced users who want +// to operate on arbitrary types (not known at compile time) may want to +// read descriptors in order to learn about the contents of a message. +// A very small number of users will want to construct their own +// Descriptors, either because they are implementing Message manually or +// because they are writing something like the protocol compiler. +// +// For an example of how you might use descriptors, see the code example +// at the top of message.h. + +#ifndef GOOGLE_PROTOBUF_DESCRIPTOR_H__ +#define GOOGLE_PROTOBUF_DESCRIPTOR_H__ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// TYPE_BOOL is defined in the MacOS's ConditionalMacros.h. +#ifdef TYPE_BOOL +#undef TYPE_BOOL +#endif // TYPE_BOOL + +#ifdef SWIG +#define PROTOBUF_EXPORT +#endif + + +namespace google { +namespace protobuf { + +// Defined in this file. +class Descriptor; +class FieldDescriptor; +class OneofDescriptor; +class EnumDescriptor; +class EnumValueDescriptor; +class ServiceDescriptor; +class MethodDescriptor; +class FileDescriptor; +class DescriptorDatabase; +class DescriptorPool; + +// Defined in descriptor.proto +class DescriptorProto; +class DescriptorProto_ExtensionRange; +class FieldDescriptorProto; +class OneofDescriptorProto; +class EnumDescriptorProto; +class EnumValueDescriptorProto; +class ServiceDescriptorProto; +class MethodDescriptorProto; +class FileDescriptorProto; +class MessageOptions; +class FieldOptions; +class OneofOptions; +class EnumOptions; +class EnumValueOptions; +class ExtensionRangeOptions; +class ServiceOptions; +class MethodOptions; +class FileOptions; +class UninterpretedOption; +class SourceCodeInfo; + +// Defined in message.h +class Message; +class Reflection; + +// Defined in descriptor.cc +class DescriptorBuilder; +class FileDescriptorTables; +struct Symbol; + +// Defined in unknown_field_set.h. +class UnknownField; + +// Defined in command_line_interface.cc +namespace compiler { +class CommandLineInterface; +namespace cpp { +// Defined in helpers.h +class Formatter; +} // namespace cpp +} // namespace compiler + +namespace descriptor_unittest { +class DescriptorTest; +} // namespace descriptor_unittest + +// Defined in printer.h +namespace io { +class Printer; +} // namespace io + +// NB, all indices are zero-based. +struct SourceLocation { + int start_line; + int end_line; + int start_column; + int end_column; + + // Doc comments found at the source location. + // See the comments in SourceCodeInfo.Location (descriptor.proto) for details. + std::string leading_comments; + std::string trailing_comments; + std::vector leading_detached_comments; +}; + +// Options when generating machine-parsable output from a descriptor with +// DebugString(). +struct DebugStringOptions { + // include original user comments as recorded in SourceLocation entries. N.B. + // that this must be |false| by default: several other pieces of code (for + // example, the C++ code generation for fields in the proto compiler) rely on + // DebugString() output being unobstructed by user comments. + bool include_comments; + // If true, elide the braced body in the debug string. + bool elide_group_body; + bool elide_oneof_body; + + DebugStringOptions() + : include_comments(false), + elide_group_body(false), + elide_oneof_body(false) { + } +}; + +// A class to handle the simplest cases of a lazily linked descriptor +// for a message type that isn't built at the time of cross linking, +// which is needed when a pool has lazily_build_dependencies_ set. +// Must be instantiated as mutable in a descriptor. +namespace internal { +class PROTOBUF_EXPORT LazyDescriptor { + public: + // Init function to be called at init time of a descriptor containing + // a LazyDescriptor. + void Init() { + descriptor_ = nullptr; + name_ = nullptr; + once_ = nullptr; + file_ = nullptr; + } + + // Sets the value of the descriptor if it is known during the descriptor + // building process. Not thread safe, should only be called during the + // descriptor build process. Should not be called after SetLazy has been + // called. + void Set(const Descriptor* descriptor); + + // Sets the information needed to lazily cross link the descriptor at a later + // time, SetLazy is not thread safe, should be called only once at descriptor + // build time if the symbol wasn't found and building of the file containing + // that type is delayed because lazily_build_dependencies_ is set on the pool. + // Should not be called after Set() has been called. + void SetLazy(StringPiece name, const FileDescriptor* file); + + // Returns the current value of the descriptor, thread-safe. If SetLazy(...) + // has been called, will do a one-time cross link of the type specified, + // building the descriptor file that contains the type if necessary. + inline const Descriptor* Get() { + Once(); + return descriptor_; + } + + private: + static void OnceStatic(LazyDescriptor* lazy); + void OnceInternal(); + void Once(); + + const Descriptor* descriptor_; + const std::string* name_; + internal::once_flag* once_; + const FileDescriptor* file_; +}; +} // namespace internal + +// Describes a type of protocol message, or a particular group within a +// message. To obtain the Descriptor for a given message object, call +// Message::GetDescriptor(). Generated message classes also have a +// static method called descriptor() which returns the type's descriptor. +// Use DescriptorPool to construct your own descriptors. +class PROTOBUF_EXPORT Descriptor { + public: + typedef DescriptorProto Proto; + + // The name of the message type, not including its scope. + const std::string& name() const; + + // The fully-qualified name of the message type, scope delimited by + // periods. For example, message type "Foo" which is declared in package + // "bar" has full name "bar.Foo". If a type "Baz" is nested within + // Foo, Baz's full_name is "bar.Foo.Baz". To get only the part that + // comes after the last '.', use name(). + const std::string& full_name() const; + + // Index of this descriptor within the file or containing type's message + // type array. + int index() const; + + // The .proto file in which this message type was defined. Never nullptr. + const FileDescriptor* file() const; + + // If this Descriptor describes a nested type, this returns the type + // in which it is nested. Otherwise, returns nullptr. + const Descriptor* containing_type() const; + + // Get options for this message type. These are specified in the .proto file + // by placing lines like "option foo = 1234;" in the message definition. + // Allowed options are defined by MessageOptions in descriptor.proto, and any + // available extensions of that message. + const MessageOptions& options() const; + + // Write the contents of this Descriptor into the given DescriptorProto. + // The target DescriptorProto must be clear before calling this; if it + // isn't, the result may be garbage. + void CopyTo(DescriptorProto* proto) const; + + // Write the contents of this descriptor in a human-readable form. Output + // will be suitable for re-parsing. + std::string DebugString() const; + + // Similar to DebugString(), but additionally takes options (e.g., + // include original user comments in output). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Returns true if this is a placeholder for an unknown type. This will + // only be the case if this descriptor comes from a DescriptorPool + // with AllowUnknownDependencies() set. + bool is_placeholder() const; + + enum WellKnownType { + WELLKNOWNTYPE_UNSPECIFIED, // Not a well-known type. + + // Wrapper types. + WELLKNOWNTYPE_DOUBLEVALUE, // google.protobuf.DoubleValue + WELLKNOWNTYPE_FLOATVALUE, // google.protobuf.FloatValue + WELLKNOWNTYPE_INT64VALUE, // google.protobuf.Int64Value + WELLKNOWNTYPE_UINT64VALUE, // google.protobuf.UInt64Value + WELLKNOWNTYPE_INT32VALUE, // google.protobuf.Int32Value + WELLKNOWNTYPE_UINT32VALUE, // google.protobuf.UInt32Value + WELLKNOWNTYPE_STRINGVALUE, // google.protobuf.StringValue + WELLKNOWNTYPE_BYTESVALUE, // google.protobuf.BytesValue + WELLKNOWNTYPE_BOOLVALUE, // google.protobuf.BoolValue + + // Other well known types. + WELLKNOWNTYPE_ANY, // google.protobuf.Any + WELLKNOWNTYPE_FIELDMASK, // google.protobuf.FieldMask + WELLKNOWNTYPE_DURATION, // google.protobuf.Duration + WELLKNOWNTYPE_TIMESTAMP, // google.protobuf.Timestamp + WELLKNOWNTYPE_VALUE, // google.protobuf.Value + WELLKNOWNTYPE_LISTVALUE, // google.protobuf.ListValue + WELLKNOWNTYPE_STRUCT, // google.protobuf.Struct + + // New well-known types may be added in the future. + // Please make sure any switch() statements have a 'default' case. + __WELLKNOWNTYPE__DO_NOT_USE__ADD_DEFAULT_INSTEAD__, + }; + + WellKnownType well_known_type() const; + + // Field stuff ----------------------------------------------------- + + // The number of fields in this message type. + int field_count() const; + // Gets a field by index, where 0 <= index < field_count(). + // These are returned in the order they were defined in the .proto file. + const FieldDescriptor* field(int index) const; + + // Looks up a field by declared tag number. Returns nullptr if no such field + // exists. + const FieldDescriptor* FindFieldByNumber(int number) const; + // Looks up a field by name. Returns nullptr if no such field exists. + const FieldDescriptor* FindFieldByName(ConstStringParam name) const; + + // Looks up a field by lowercased name (as returned by lowercase_name()). + // This lookup may be ambiguous if multiple field names differ only by case, + // in which case the field returned is chosen arbitrarily from the matches. + const FieldDescriptor* FindFieldByLowercaseName( + ConstStringParam lowercase_name) const; + + // Looks up a field by camel-case name (as returned by camelcase_name()). + // This lookup may be ambiguous if multiple field names differ in a way that + // leads them to have identical camel-case names, in which case the field + // returned is chosen arbitrarily from the matches. + const FieldDescriptor* FindFieldByCamelcaseName( + ConstStringParam camelcase_name) const; + + // The number of oneofs in this message type. + int oneof_decl_count() const; + // The number of oneofs in this message type, excluding synthetic oneofs. + // Real oneofs always come first, so iterating up to real_oneof_decl_cout() + // will yield all real oneofs. + int real_oneof_decl_count() const; + // Get a oneof by index, where 0 <= index < oneof_decl_count(). + // These are returned in the order they were defined in the .proto file. + const OneofDescriptor* oneof_decl(int index) const; + + // Looks up a oneof by name. Returns nullptr if no such oneof exists. + const OneofDescriptor* FindOneofByName(ConstStringParam name) const; + + // Nested type stuff ----------------------------------------------- + + // The number of nested types in this message type. + int nested_type_count() const; + // Gets a nested type by index, where 0 <= index < nested_type_count(). + // These are returned in the order they were defined in the .proto file. + const Descriptor* nested_type(int index) const; + + // Looks up a nested type by name. Returns nullptr if no such nested type + // exists. + const Descriptor* FindNestedTypeByName(ConstStringParam name) const; + + // Enum stuff ------------------------------------------------------ + + // The number of enum types in this message type. + int enum_type_count() const; + // Gets an enum type by index, where 0 <= index < enum_type_count(). + // These are returned in the order they were defined in the .proto file. + const EnumDescriptor* enum_type(int index) const; + + // Looks up an enum type by name. Returns nullptr if no such enum type + // exists. + const EnumDescriptor* FindEnumTypeByName(ConstStringParam name) const; + + // Looks up an enum value by name, among all enum types in this message. + // Returns nullptr if no such value exists. + const EnumValueDescriptor* FindEnumValueByName(ConstStringParam name) const; + + // Extensions ------------------------------------------------------ + + // A range of field numbers which are designated for third-party + // extensions. + struct ExtensionRange { + typedef DescriptorProto_ExtensionRange Proto; + + typedef ExtensionRangeOptions OptionsType; + + // See Descriptor::CopyTo(). + void CopyTo(DescriptorProto_ExtensionRange* proto) const; + + int start; // inclusive + int end; // exclusive + + const ExtensionRangeOptions* options_; + }; + + // The number of extension ranges in this message type. + int extension_range_count() const; + // Gets an extension range by index, where 0 <= index < + // extension_range_count(). These are returned in the order they were defined + // in the .proto file. + const ExtensionRange* extension_range(int index) const; + + // Returns true if the number is in one of the extension ranges. + bool IsExtensionNumber(int number) const; + + // Returns nullptr if no extension range contains the given number. + const ExtensionRange* FindExtensionRangeContainingNumber(int number) const; + + // The number of extensions defined nested within this message type's scope. + // See doc: + // https://developers.google.com/protocol-buffers/docs/proto#nested-extensions + // + // Note that the extensions may be extending *other* messages. + // + // For example: + // message M1 { + // extensions 1 to max; + // } + // + // message M2 { + // extend M1 { + // optional int32 foo = 1; + // } + // } + // + // In this case, + // DescriptorPool::generated_pool() + // ->FindMessageTypeByName("M2") + // ->extension(0) + // will return "foo", even though "foo" is an extension of M1. + // To find all known extensions of a given message, instead use + // DescriptorPool::FindAllExtensions. + int extension_count() const; + // Get an extension by index, where 0 <= index < extension_count(). + // These are returned in the order they were defined in the .proto file. + const FieldDescriptor* extension(int index) const; + + // Looks up a named extension (which extends some *other* message type) + // defined within this message type's scope. + const FieldDescriptor* FindExtensionByName(ConstStringParam name) const; + + // Similar to FindFieldByLowercaseName(), but finds extensions defined within + // this message type's scope. + const FieldDescriptor* FindExtensionByLowercaseName( + ConstStringParam name) const; + + // Similar to FindFieldByCamelcaseName(), but finds extensions defined within + // this message type's scope. + const FieldDescriptor* FindExtensionByCamelcaseName( + ConstStringParam name) const; + + // Reserved fields ------------------------------------------------- + + // A range of reserved field numbers. + struct ReservedRange { + int start; // inclusive + int end; // exclusive + }; + + // The number of reserved ranges in this message type. + int reserved_range_count() const; + // Gets an reserved range by index, where 0 <= index < + // reserved_range_count(). These are returned in the order they were defined + // in the .proto file. + const ReservedRange* reserved_range(int index) const; + + // Returns true if the number is in one of the reserved ranges. + bool IsReservedNumber(int number) const; + + // Returns nullptr if no reserved range contains the given number. + const ReservedRange* FindReservedRangeContainingNumber(int number) const; + + // The number of reserved field names in this message type. + int reserved_name_count() const; + + // Gets a reserved name by index, where 0 <= index < reserved_name_count(). + const std::string& reserved_name(int index) const; + + // Returns true if the field name is reserved. + bool IsReservedName(ConstStringParam name) const; + + // Source Location --------------------------------------------------- + + // Updates |*out_location| to the source location of the complete + // extent of this message declaration. Returns false and leaves + // |*out_location| unchanged iff location information was not available. + bool GetSourceLocation(SourceLocation* out_location) const; + + // Maps -------------------------------------------------------------- + + // Returns the FieldDescriptor for the "key" field. If this isn't a map entry + // field, returns nullptr. + const FieldDescriptor* map_key() const; + + // Returns the FieldDescriptor for the "value" field. If this isn't a map + // entry field, returns nullptr. + const FieldDescriptor* map_value() const; + + private: + typedef MessageOptions OptionsType; + + // Allows tests to test CopyTo(proto, true). + friend class descriptor_unittest::DescriptorTest; + + // Allows access to GetLocationPath for annotations. + friend class io::Printer; + friend class compiler::cpp::Formatter; + + // Fill the json_name field of FieldDescriptorProto. + void CopyJsonNameTo(DescriptorProto* proto) const; + + // Internal version of DebugString; controls the level of indenting for + // correct depth. Takes |options| to control debug-string options, and + // |include_opening_clause| to indicate whether the "message ... " part of the + // clause has already been generated (this varies depending on context). + void DebugString(int depth, std::string* contents, + const DebugStringOptions& options, + bool include_opening_clause) const; + + // Walks up the descriptor tree to generate the source location path + // to this descriptor from the file root. + void GetLocationPath(std::vector* output) const; + + const std::string* name_; + const std::string* full_name_; + const FileDescriptor* file_; + const Descriptor* containing_type_; + const MessageOptions* options_; + + // These arrays are separated from their sizes to minimize padding on 64-bit. + FieldDescriptor* fields_; + OneofDescriptor* oneof_decls_; + Descriptor* nested_types_; + EnumDescriptor* enum_types_; + ExtensionRange* extension_ranges_; + FieldDescriptor* extensions_; + ReservedRange* reserved_ranges_; + const std::string** reserved_names_; + + int field_count_; + int oneof_decl_count_; + int real_oneof_decl_count_; + int nested_type_count_; + int enum_type_count_; + int extension_range_count_; + int extension_count_; + int reserved_range_count_; + int reserved_name_count_; + + // True if this is a placeholder for an unknown type. + bool is_placeholder_; + // True if this is a placeholder and the type name wasn't fully-qualified. + bool is_unqualified_placeholder_; + // Well known type. Stored as char to conserve space. + char well_known_type_; + + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() in descriptor.cc + // and update them to initialize the field. + + // Must be constructed using DescriptorPool. + Descriptor() {} + friend class DescriptorBuilder; + friend class DescriptorPool; + friend class EnumDescriptor; + friend class FieldDescriptor; + friend class OneofDescriptor; + friend class MethodDescriptor; + friend class FileDescriptor; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Descriptor); +}; + + +// Describes a single field of a message. To get the descriptor for a given +// field, first get the Descriptor for the message in which it is defined, +// then call Descriptor::FindFieldByName(). To get a FieldDescriptor for +// an extension, do one of the following: +// - Get the Descriptor or FileDescriptor for its containing scope, then +// call Descriptor::FindExtensionByName() or +// FileDescriptor::FindExtensionByName(). +// - Given a DescriptorPool, call DescriptorPool::FindExtensionByNumber() or +// DescriptorPool::FindExtensionByPrintableName(). +// Use DescriptorPool to construct your own descriptors. +class PROTOBUF_EXPORT FieldDescriptor { + public: + typedef FieldDescriptorProto Proto; + + // Identifies a field type. 0 is reserved for errors. The order is weird + // for historical reasons. Types 12 and up are new in proto2. + enum Type { + TYPE_DOUBLE = 1, // double, exactly eight bytes on the wire. + TYPE_FLOAT = 2, // float, exactly four bytes on the wire. + TYPE_INT64 = 3, // int64, varint on the wire. Negative numbers + // take 10 bytes. Use TYPE_SINT64 if negative + // values are likely. + TYPE_UINT64 = 4, // uint64, varint on the wire. + TYPE_INT32 = 5, // int32, varint on the wire. Negative numbers + // take 10 bytes. Use TYPE_SINT32 if negative + // values are likely. + TYPE_FIXED64 = 6, // uint64, exactly eight bytes on the wire. + TYPE_FIXED32 = 7, // uint32, exactly four bytes on the wire. + TYPE_BOOL = 8, // bool, varint on the wire. + TYPE_STRING = 9, // UTF-8 text. + TYPE_GROUP = 10, // Tag-delimited message. Deprecated. + TYPE_MESSAGE = 11, // Length-delimited message. + + TYPE_BYTES = 12, // Arbitrary byte array. + TYPE_UINT32 = 13, // uint32, varint on the wire + TYPE_ENUM = 14, // Enum, varint on the wire + TYPE_SFIXED32 = 15, // int32, exactly four bytes on the wire + TYPE_SFIXED64 = 16, // int64, exactly eight bytes on the wire + TYPE_SINT32 = 17, // int32, ZigZag-encoded varint on the wire + TYPE_SINT64 = 18, // int64, ZigZag-encoded varint on the wire + + MAX_TYPE = 18, // Constant useful for defining lookup tables + // indexed by Type. + }; + + // Specifies the C++ data type used to represent the field. There is a + // fixed mapping from Type to CppType where each Type maps to exactly one + // CppType. 0 is reserved for errors. + enum CppType { + CPPTYPE_INT32 = 1, // TYPE_INT32, TYPE_SINT32, TYPE_SFIXED32 + CPPTYPE_INT64 = 2, // TYPE_INT64, TYPE_SINT64, TYPE_SFIXED64 + CPPTYPE_UINT32 = 3, // TYPE_UINT32, TYPE_FIXED32 + CPPTYPE_UINT64 = 4, // TYPE_UINT64, TYPE_FIXED64 + CPPTYPE_DOUBLE = 5, // TYPE_DOUBLE + CPPTYPE_FLOAT = 6, // TYPE_FLOAT + CPPTYPE_BOOL = 7, // TYPE_BOOL + CPPTYPE_ENUM = 8, // TYPE_ENUM + CPPTYPE_STRING = 9, // TYPE_STRING, TYPE_BYTES + CPPTYPE_MESSAGE = 10, // TYPE_MESSAGE, TYPE_GROUP + + MAX_CPPTYPE = 10, // Constant useful for defining lookup tables + // indexed by CppType. + }; + + // Identifies whether the field is optional, required, or repeated. 0 is + // reserved for errors. + enum Label { + LABEL_OPTIONAL = 1, // optional + LABEL_REQUIRED = 2, // required + LABEL_REPEATED = 3, // repeated + + MAX_LABEL = 3, // Constant useful for defining lookup tables + // indexed by Label. + }; + + // Valid field numbers are positive integers up to kMaxNumber. + static const int kMaxNumber = (1 << 29) - 1; + + // First field number reserved for the protocol buffer library implementation. + // Users may not declare fields that use reserved numbers. + static const int kFirstReservedNumber = 19000; + // Last field number reserved for the protocol buffer library implementation. + // Users may not declare fields that use reserved numbers. + static const int kLastReservedNumber = 19999; + + const std::string& name() const; // Name of this field within the message. + const std::string& full_name() const; // Fully-qualified name of the field. + const std::string& json_name() const; // JSON name of this field. + const FileDescriptor* file() const; // File in which this field was defined. + bool is_extension() const; // Is this an extension field? + int number() const; // Declared tag number. + + // Same as name() except converted to lower-case. This (and especially the + // FindFieldByLowercaseName() method) can be useful when parsing formats + // which prefer to use lowercase naming style. (Although, technically + // field names should be lowercased anyway according to the protobuf style + // guide, so this only makes a difference when dealing with old .proto files + // which do not follow the guide.) + const std::string& lowercase_name() const; + + // Same as name() except converted to camel-case. In this conversion, any + // time an underscore appears in the name, it is removed and the next + // letter is capitalized. Furthermore, the first letter of the name is + // lower-cased. Examples: + // FooBar -> fooBar + // foo_bar -> fooBar + // fooBar -> fooBar + // This (and especially the FindFieldByCamelcaseName() method) can be useful + // when parsing formats which prefer to use camel-case naming style. + const std::string& camelcase_name() const; + + Type type() const; // Declared type of this field. + const char* type_name() const; // Name of the declared type. + CppType cpp_type() const; // C++ type of this field. + const char* cpp_type_name() const; // Name of the C++ type. + Label label() const; // optional/required/repeated + + bool is_required() const; // shorthand for label() == LABEL_REQUIRED + bool is_optional() const; // shorthand for label() == LABEL_OPTIONAL + bool is_repeated() const; // shorthand for label() == LABEL_REPEATED + bool is_packable() const; // shorthand for is_repeated() && + // IsTypePackable(type()) + bool is_packed() const; // shorthand for is_packable() && + // options().packed() + bool is_map() const; // shorthand for type() == TYPE_MESSAGE && + // message_type()->options().map_entry() + + // Returns true if this field was syntactically written with "optional" in the + // .proto file. Excludes singular proto3 fields that do not have a label. + bool has_optional_keyword() const; + + // Returns true if this field tracks presence, ie. does the field + // distinguish between "unset" and "present with default value." + // This includes required, optional, and oneof fields. It excludes maps, + // repeated fields, and singular proto3 fields without "optional". + // + // For fields where has_presence() == true, the return value of + // Reflection::HasField() is semantically meaningful. + bool has_presence() const; + + // Index of this field within the message's field array, or the file or + // extension scope's extensions array. + int index() const; + + // Does this field have an explicitly-declared default value? + bool has_default_value() const; + + // Whether the user has specified the json_name field option in the .proto + // file. + bool has_json_name() const; + + // Get the field default value if cpp_type() == CPPTYPE_INT32. If no + // explicit default was defined, the default is 0. + int32 default_value_int32() const; + // Get the field default value if cpp_type() == CPPTYPE_INT64. If no + // explicit default was defined, the default is 0. + int64 default_value_int64() const; + // Get the field default value if cpp_type() == CPPTYPE_UINT32. If no + // explicit default was defined, the default is 0. + uint32 default_value_uint32() const; + // Get the field default value if cpp_type() == CPPTYPE_UINT64. If no + // explicit default was defined, the default is 0. + uint64 default_value_uint64() const; + // Get the field default value if cpp_type() == CPPTYPE_FLOAT. If no + // explicit default was defined, the default is 0.0. + float default_value_float() const; + // Get the field default value if cpp_type() == CPPTYPE_DOUBLE. If no + // explicit default was defined, the default is 0.0. + double default_value_double() const; + // Get the field default value if cpp_type() == CPPTYPE_BOOL. If no + // explicit default was defined, the default is false. + bool default_value_bool() const; + // Get the field default value if cpp_type() == CPPTYPE_ENUM. If no + // explicit default was defined, the default is the first value defined + // in the enum type (all enum types are required to have at least one value). + // This never returns nullptr. + const EnumValueDescriptor* default_value_enum() const; + // Get the field default value if cpp_type() == CPPTYPE_STRING. If no + // explicit default was defined, the default is the empty string. + const std::string& default_value_string() const; + + // The Descriptor for the message of which this is a field. For extensions, + // this is the extended type. Never nullptr. + const Descriptor* containing_type() const; + + // If the field is a member of a oneof, this is the one, otherwise this is + // nullptr. + const OneofDescriptor* containing_oneof() const; + + // If the field is a member of a non-synthetic oneof, returns the descriptor + // for the oneof, otherwise returns nullptr. + const OneofDescriptor* real_containing_oneof() const; + + // If the field is a member of a oneof, returns the index in that oneof. + int index_in_oneof() const; + + // An extension may be declared within the scope of another message. If this + // field is an extension (is_extension() is true), then extension_scope() + // returns that message, or nullptr if the extension was declared at global + // scope. If this is not an extension, extension_scope() is undefined (may + // assert-fail). + const Descriptor* extension_scope() const; + + // If type is TYPE_MESSAGE or TYPE_GROUP, returns a descriptor for the + // message or the group type. Otherwise, returns null. + const Descriptor* message_type() const; + // If type is TYPE_ENUM, returns a descriptor for the enum. Otherwise, + // returns null. + const EnumDescriptor* enum_type() const; + + // Get the FieldOptions for this field. This includes things listed in + // square brackets after the field definition. E.g., the field: + // optional string text = 1 [ctype=CORD]; + // has the "ctype" option set. Allowed options are defined by FieldOptions in + // descriptor.proto, and any available extensions of that message. + const FieldOptions& options() const; + + // See Descriptor::CopyTo(). + void CopyTo(FieldDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + std::string DebugString() const; + + // See Descriptor::DebugStringWithOptions(). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Helper method to get the CppType for a particular Type. + static CppType TypeToCppType(Type type); + + // Helper method to get the name of a Type. + static const char* TypeName(Type type); + + // Helper method to get the name of a CppType. + static const char* CppTypeName(CppType cpp_type); + + // Return true iff [packed = true] is valid for fields of this type. + static inline bool IsTypePackable(Type field_type); + + // Returns full_name() except if the field is a MessageSet extension, + // in which case it returns the full_name() of the containing message type + // for backwards compatibility with proto1. + // + // A MessageSet extension is defined as an optional message extension + // whose containing type has the message_set_wire_format option set. + // This should be true of extensions of google.protobuf.bridge.MessageSet; + // by convention, such extensions are named "message_set_extension". + // + // The opposite operation (looking up an extension's FieldDescriptor given + // its printable name) can be accomplished with + // message->file()->pool()->FindExtensionByPrintableName(message, name) + // where the extension extends "message". + const std::string& PrintableNameForExtension() const; + + // Source Location --------------------------------------------------- + + // Updates |*out_location| to the source location of the complete + // extent of this field declaration. Returns false and leaves + // |*out_location| unchanged iff location information was not available. + bool GetSourceLocation(SourceLocation* out_location) const; + + private: + typedef FieldOptions OptionsType; + + // Allows access to GetLocationPath for annotations. + friend class io::Printer; + friend class compiler::cpp::Formatter; + + // Fill the json_name field of FieldDescriptorProto. + void CopyJsonNameTo(FieldDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + void DebugString(int depth, std::string* contents, + const DebugStringOptions& options) const; + + // formats the default value appropriately and returns it as a string. + // Must have a default value to call this. If quote_string_type is true, then + // types of CPPTYPE_STRING whill be surrounded by quotes and CEscaped. + std::string DefaultValueAsString(bool quote_string_type) const; + + // Helper function that returns the field type name for DebugString. + std::string FieldTypeNameDebugString() const; + + // Walks up the descriptor tree to generate the source location path + // to this descriptor from the file root. + void GetLocationPath(std::vector* output) const; + + // Returns true if this is a map message type. + bool is_map_message_type() const; + + const std::string* name_; + const std::string* full_name_; + const std::string* lowercase_name_; + const std::string* camelcase_name_; + // If has_json_name_ is true, it's the value specified by the user. + // Otherwise, it has the same value as camelcase_name_. + const std::string* json_name_; + const FileDescriptor* file_; + internal::once_flag* type_once_; + static void TypeOnceInit(const FieldDescriptor* to_init); + void InternalTypeOnceInit() const; + mutable Type type_; + Label label_; + bool has_default_value_; + bool proto3_optional_; + // Whether the user has specified the json_name field option in the .proto + // file. + bool has_json_name_; + bool is_extension_; + int number_; + int index_in_oneof_; + const Descriptor* containing_type_; + const OneofDescriptor* containing_oneof_; + const Descriptor* extension_scope_; + mutable const Descriptor* message_type_; + mutable const EnumDescriptor* enum_type_; + const FieldOptions* options_; + const std::string* type_name_; + const std::string* default_value_enum_name_; + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() in + // descriptor.cc and update them to initialize the field. + + union { + int32 default_value_int32_; + int64 default_value_int64_; + uint32 default_value_uint32_; + uint64 default_value_uint64_; + float default_value_float_; + double default_value_double_; + bool default_value_bool_; + + mutable const EnumValueDescriptor* default_value_enum_; + const std::string* default_value_string_; + }; + + static const CppType kTypeToCppTypeMap[MAX_TYPE + 1]; + + static const char* const kTypeToName[MAX_TYPE + 1]; + + static const char* const kCppTypeToName[MAX_CPPTYPE + 1]; + + static const char* const kLabelToName[MAX_LABEL + 1]; + + // Must be constructed using DescriptorPool. + FieldDescriptor() {} + friend class DescriptorBuilder; + friend class FileDescriptor; + friend class Descriptor; + friend class OneofDescriptor; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FieldDescriptor); +}; + + +// Describes a oneof defined in a message type. +class PROTOBUF_EXPORT OneofDescriptor { + public: + typedef OneofDescriptorProto Proto; + + const std::string& name() const; // Name of this oneof. + const std::string& full_name() const; // Fully-qualified name of the oneof. + + // Index of this oneof within the message's oneof array. + int index() const; + + // Returns whether this oneof was inserted by the compiler to wrap a proto3 + // optional field. If this returns true, code generators should *not* emit it. + bool is_synthetic() const; + + // The .proto file in which this oneof was defined. Never nullptr. + const FileDescriptor* file() const; + // The Descriptor for the message containing this oneof. + const Descriptor* containing_type() const; + + // The number of (non-extension) fields which are members of this oneof. + int field_count() const; + // Get a member of this oneof, in the order in which they were declared in the + // .proto file. Does not include extensions. + const FieldDescriptor* field(int index) const; + + const OneofOptions& options() const; + + // See Descriptor::CopyTo(). + void CopyTo(OneofDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + std::string DebugString() const; + + // See Descriptor::DebugStringWithOptions(). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Source Location --------------------------------------------------- + + // Updates |*out_location| to the source location of the complete + // extent of this oneof declaration. Returns false and leaves + // |*out_location| unchanged iff location information was not available. + bool GetSourceLocation(SourceLocation* out_location) const; + + private: + typedef OneofOptions OptionsType; + + // Allows access to GetLocationPath for annotations. + friend class io::Printer; + friend class compiler::cpp::Formatter; + + // See Descriptor::DebugString(). + void DebugString(int depth, std::string* contents, + const DebugStringOptions& options) const; + + // Walks up the descriptor tree to generate the source location path + // to this descriptor from the file root. + void GetLocationPath(std::vector* output) const; + + const std::string* name_; + const std::string* full_name_; + const Descriptor* containing_type_; + int field_count_; + const FieldDescriptor** fields_; + const OneofOptions* options_; + + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() + // in descriptor.cc and update them to initialize the field. + + // Must be constructed using DescriptorPool. + OneofDescriptor() {} + friend class DescriptorBuilder; + friend class Descriptor; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(OneofDescriptor); +}; + +// Describes an enum type defined in a .proto file. To get the EnumDescriptor +// for a generated enum type, call TypeName_descriptor(). Use DescriptorPool +// to construct your own descriptors. +class PROTOBUF_EXPORT EnumDescriptor { + public: + typedef EnumDescriptorProto Proto; + + // The name of this enum type in the containing scope. + const std::string& name() const; + + // The fully-qualified name of the enum type, scope delimited by periods. + const std::string& full_name() const; + + // Index of this enum within the file or containing message's enum array. + int index() const; + + // The .proto file in which this enum type was defined. Never nullptr. + const FileDescriptor* file() const; + + // The number of values for this EnumDescriptor. Guaranteed to be greater + // than zero. + int value_count() const; + // Gets a value by index, where 0 <= index < value_count(). + // These are returned in the order they were defined in the .proto file. + const EnumValueDescriptor* value(int index) const; + + // Looks up a value by name. Returns nullptr if no such value exists. + const EnumValueDescriptor* FindValueByName(ConstStringParam name) const; + // Looks up a value by number. Returns nullptr if no such value exists. If + // multiple values have this number, the first one defined is returned. + const EnumValueDescriptor* FindValueByNumber(int number) const; + + // If this enum type is nested in a message type, this is that message type. + // Otherwise, nullptr. + const Descriptor* containing_type() const; + + // Get options for this enum type. These are specified in the .proto file by + // placing lines like "option foo = 1234;" in the enum definition. Allowed + // options are defined by EnumOptions in descriptor.proto, and any available + // extensions of that message. + const EnumOptions& options() const; + + // See Descriptor::CopyTo(). + void CopyTo(EnumDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + std::string DebugString() const; + + // See Descriptor::DebugStringWithOptions(). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Returns true if this is a placeholder for an unknown enum. This will + // only be the case if this descriptor comes from a DescriptorPool + // with AllowUnknownDependencies() set. + bool is_placeholder() const; + + // Reserved fields ------------------------------------------------- + + // A range of reserved field numbers. + struct ReservedRange { + int start; // inclusive + int end; // inclusive + }; + + // The number of reserved ranges in this message type. + int reserved_range_count() const; + // Gets an reserved range by index, where 0 <= index < + // reserved_range_count(). These are returned in the order they were defined + // in the .proto file. + const EnumDescriptor::ReservedRange* reserved_range(int index) const; + + // Returns true if the number is in one of the reserved ranges. + bool IsReservedNumber(int number) const; + + // Returns nullptr if no reserved range contains the given number. + const EnumDescriptor::ReservedRange* FindReservedRangeContainingNumber( + int number) const; + + // The number of reserved field names in this message type. + int reserved_name_count() const; + + // Gets a reserved name by index, where 0 <= index < reserved_name_count(). + const std::string& reserved_name(int index) const; + + // Returns true if the field name is reserved. + bool IsReservedName(ConstStringParam name) const; + + // Source Location --------------------------------------------------- + + // Updates |*out_location| to the source location of the complete + // extent of this enum declaration. Returns false and leaves + // |*out_location| unchanged iff location information was not available. + bool GetSourceLocation(SourceLocation* out_location) const; + + private: + typedef EnumOptions OptionsType; + + // Allows access to GetLocationPath for annotations. + friend class io::Printer; + friend class compiler::cpp::Formatter; + + // Looks up a value by number. If the value does not exist, dynamically + // creates a new EnumValueDescriptor for that value, assuming that it was + // unknown. If a new descriptor is created, this is done in a thread-safe way, + // and future calls will return the same value descriptor pointer. + // + // This is private but is used by Reflection (which is friended below) to + // return a valid EnumValueDescriptor from GetEnum() when this feature is + // enabled. + const EnumValueDescriptor* FindValueByNumberCreatingIfUnknown( + int number) const; + + // See Descriptor::DebugString(). + void DebugString(int depth, std::string* contents, + const DebugStringOptions& options) const; + + // Walks up the descriptor tree to generate the source location path + // to this descriptor from the file root. + void GetLocationPath(std::vector* output) const; + + const std::string* name_; + const std::string* full_name_; + const FileDescriptor* file_; + const Descriptor* containing_type_; + const EnumOptions* options_; + + // True if this is a placeholder for an unknown type. + bool is_placeholder_; + // True if this is a placeholder and the type name wasn't fully-qualified. + bool is_unqualified_placeholder_; + + int value_count_; + EnumValueDescriptor* values_; + + int reserved_range_count_; + int reserved_name_count_; + EnumDescriptor::ReservedRange* reserved_ranges_; + const std::string** reserved_names_; + + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() in + // descriptor.cc and update them to initialize the field. + + // Must be constructed using DescriptorPool. + EnumDescriptor() {} + friend class DescriptorBuilder; + friend class Descriptor; + friend class FieldDescriptor; + friend class EnumValueDescriptor; + friend class FileDescriptor; + friend class DescriptorPool; + friend class Reflection; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(EnumDescriptor); +}; + +// Describes an individual enum constant of a particular type. To get the +// EnumValueDescriptor for a given enum value, first get the EnumDescriptor +// for its type, then use EnumDescriptor::FindValueByName() or +// EnumDescriptor::FindValueByNumber(). Use DescriptorPool to construct +// your own descriptors. +class PROTOBUF_EXPORT EnumValueDescriptor { + public: + typedef EnumValueDescriptorProto Proto; + + const std::string& name() const; // Name of this enum constant. + int index() const; // Index within the enums's Descriptor. + int number() const; // Numeric value of this enum constant. + + // The full_name of an enum value is a sibling symbol of the enum type. + // e.g. the full name of FieldDescriptorProto::TYPE_INT32 is actually + // "google.protobuf.FieldDescriptorProto.TYPE_INT32", NOT + // "google.protobuf.FieldDescriptorProto.Type.TYPE_INT32". This is to conform + // with C++ scoping rules for enums. + const std::string& full_name() const; + + // The .proto file in which this value was defined. Never nullptr. + const FileDescriptor* file() const; + // The type of this value. Never nullptr. + const EnumDescriptor* type() const; + + // Get options for this enum value. These are specified in the .proto file by + // adding text like "[foo = 1234]" after an enum value definition. Allowed + // options are defined by EnumValueOptions in descriptor.proto, and any + // available extensions of that message. + const EnumValueOptions& options() const; + + // See Descriptor::CopyTo(). + void CopyTo(EnumValueDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + std::string DebugString() const; + + // See Descriptor::DebugStringWithOptions(). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Source Location --------------------------------------------------- + + // Updates |*out_location| to the source location of the complete + // extent of this enum value declaration. Returns false and leaves + // |*out_location| unchanged iff location information was not available. + bool GetSourceLocation(SourceLocation* out_location) const; + + private: + typedef EnumValueOptions OptionsType; + + // Allows access to GetLocationPath for annotations. + friend class io::Printer; + friend class compiler::cpp::Formatter; + + // See Descriptor::DebugString(). + void DebugString(int depth, std::string* contents, + const DebugStringOptions& options) const; + + // Walks up the descriptor tree to generate the source location path + // to this descriptor from the file root. + void GetLocationPath(std::vector* output) const; + + const std::string* name_; + const std::string* full_name_; + int number_; + const EnumDescriptor* type_; + const EnumValueOptions* options_; + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() + // in descriptor.cc and update them to initialize the field. + + // Must be constructed using DescriptorPool. + EnumValueDescriptor() {} + friend class DescriptorBuilder; + friend class EnumDescriptor; + friend class DescriptorPool; + friend class FileDescriptorTables; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(EnumValueDescriptor); +}; + +// Describes an RPC service. Use DescriptorPool to construct your own +// descriptors. +class PROTOBUF_EXPORT ServiceDescriptor { + public: + typedef ServiceDescriptorProto Proto; + + // The name of the service, not including its containing scope. + const std::string& name() const; + // The fully-qualified name of the service, scope delimited by periods. + const std::string& full_name() const; + // Index of this service within the file's services array. + int index() const; + + // The .proto file in which this service was defined. Never nullptr. + const FileDescriptor* file() const; + + // Get options for this service type. These are specified in the .proto file + // by placing lines like "option foo = 1234;" in the service definition. + // Allowed options are defined by ServiceOptions in descriptor.proto, and any + // available extensions of that message. + const ServiceOptions& options() const; + + // The number of methods this service defines. + int method_count() const; + // Gets a MethodDescriptor by index, where 0 <= index < method_count(). + // These are returned in the order they were defined in the .proto file. + const MethodDescriptor* method(int index) const; + + // Look up a MethodDescriptor by name. + const MethodDescriptor* FindMethodByName(ConstStringParam name) const; + // See Descriptor::CopyTo(). + void CopyTo(ServiceDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + std::string DebugString() const; + + // See Descriptor::DebugStringWithOptions(). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Source Location --------------------------------------------------- + + // Updates |*out_location| to the source location of the complete + // extent of this service declaration. Returns false and leaves + // |*out_location| unchanged iff location information was not available. + bool GetSourceLocation(SourceLocation* out_location) const; + + private: + typedef ServiceOptions OptionsType; + + // Allows access to GetLocationPath for annotations. + friend class io::Printer; + friend class compiler::cpp::Formatter; + + // See Descriptor::DebugString(). + void DebugString(std::string* contents, + const DebugStringOptions& options) const; + + // Walks up the descriptor tree to generate the source location path + // to this descriptor from the file root. + void GetLocationPath(std::vector* output) const; + + const std::string* name_; + const std::string* full_name_; + const FileDescriptor* file_; + const ServiceOptions* options_; + MethodDescriptor* methods_; + int method_count_; + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() in + // descriptor.cc and update them to initialize the field. + + // Must be constructed using DescriptorPool. + ServiceDescriptor() {} + friend class DescriptorBuilder; + friend class FileDescriptor; + friend class MethodDescriptor; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ServiceDescriptor); +}; + + +// Describes an individual service method. To obtain a MethodDescriptor given +// a service, first get its ServiceDescriptor, then call +// ServiceDescriptor::FindMethodByName(). Use DescriptorPool to construct your +// own descriptors. +class PROTOBUF_EXPORT MethodDescriptor { + public: + typedef MethodDescriptorProto Proto; + + // Name of this method, not including containing scope. + const std::string& name() const; + // The fully-qualified name of the method, scope delimited by periods. + const std::string& full_name() const; + // Index within the service's Descriptor. + int index() const; + + // The .proto file in which this method was defined. Never nullptr. + const FileDescriptor* file() const; + // Gets the service to which this method belongs. Never nullptr. + const ServiceDescriptor* service() const; + + // Gets the type of protocol message which this method accepts as input. + const Descriptor* input_type() const; + // Gets the type of protocol message which this message produces as output. + const Descriptor* output_type() const; + + // Gets whether the client streams multiple requests. + bool client_streaming() const; + // Gets whether the server streams multiple responses. + bool server_streaming() const; + + // Get options for this method. These are specified in the .proto file by + // placing lines like "option foo = 1234;" in curly-braces after a method + // declaration. Allowed options are defined by MethodOptions in + // descriptor.proto, and any available extensions of that message. + const MethodOptions& options() const; + + // See Descriptor::CopyTo(). + void CopyTo(MethodDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + std::string DebugString() const; + + // See Descriptor::DebugStringWithOptions(). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Source Location --------------------------------------------------- + + // Updates |*out_location| to the source location of the complete + // extent of this method declaration. Returns false and leaves + // |*out_location| unchanged iff location information was not available. + bool GetSourceLocation(SourceLocation* out_location) const; + + private: + typedef MethodOptions OptionsType; + + // Allows access to GetLocationPath for annotations. + friend class io::Printer; + friend class compiler::cpp::Formatter; + + // See Descriptor::DebugString(). + void DebugString(int depth, std::string* contents, + const DebugStringOptions& options) const; + + // Walks up the descriptor tree to generate the source location path + // to this descriptor from the file root. + void GetLocationPath(std::vector* output) const; + + const std::string* name_; + const std::string* full_name_; + const ServiceDescriptor* service_; + mutable internal::LazyDescriptor input_type_; + mutable internal::LazyDescriptor output_type_; + const MethodOptions* options_; + bool client_streaming_; + bool server_streaming_; + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() in + // descriptor.cc and update them to initialize the field. + + // Must be constructed using DescriptorPool. + MethodDescriptor() {} + friend class DescriptorBuilder; + friend class ServiceDescriptor; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MethodDescriptor); +}; + + +// Describes a whole .proto file. To get the FileDescriptor for a compiled-in +// file, get the descriptor for something defined in that file and call +// descriptor->file(). Use DescriptorPool to construct your own descriptors. +class PROTOBUF_EXPORT FileDescriptor { + public: + typedef FileDescriptorProto Proto; + + // The filename, relative to the source tree. + // e.g. "foo/bar/baz.proto" + const std::string& name() const; + + // The package, e.g. "google.protobuf.compiler". + const std::string& package() const; + + // The DescriptorPool in which this FileDescriptor and all its contents were + // allocated. Never nullptr. + const DescriptorPool* pool() const; + + // The number of files imported by this one. + int dependency_count() const; + // Gets an imported file by index, where 0 <= index < dependency_count(). + // These are returned in the order they were defined in the .proto file. + const FileDescriptor* dependency(int index) const; + + // The number of files public imported by this one. + // The public dependency list is a subset of the dependency list. + int public_dependency_count() const; + // Gets a public imported file by index, where 0 <= index < + // public_dependency_count(). + // These are returned in the order they were defined in the .proto file. + const FileDescriptor* public_dependency(int index) const; + + // The number of files that are imported for weak fields. + // The weak dependency list is a subset of the dependency list. + int weak_dependency_count() const; + // Gets a weak imported file by index, where 0 <= index < + // weak_dependency_count(). + // These are returned in the order they were defined in the .proto file. + const FileDescriptor* weak_dependency(int index) const; + + // Number of top-level message types defined in this file. (This does not + // include nested types.) + int message_type_count() const; + // Gets a top-level message type, where 0 <= index < message_type_count(). + // These are returned in the order they were defined in the .proto file. + const Descriptor* message_type(int index) const; + + // Number of top-level enum types defined in this file. (This does not + // include nested types.) + int enum_type_count() const; + // Gets a top-level enum type, where 0 <= index < enum_type_count(). + // These are returned in the order they were defined in the .proto file. + const EnumDescriptor* enum_type(int index) const; + + // Number of services defined in this file. + int service_count() const; + // Gets a service, where 0 <= index < service_count(). + // These are returned in the order they were defined in the .proto file. + const ServiceDescriptor* service(int index) const; + + // Number of extensions defined at file scope. (This does not include + // extensions nested within message types.) + int extension_count() const; + // Gets an extension's descriptor, where 0 <= index < extension_count(). + // These are returned in the order they were defined in the .proto file. + const FieldDescriptor* extension(int index) const; + + // Get options for this file. These are specified in the .proto file by + // placing lines like "option foo = 1234;" at the top level, outside of any + // other definitions. Allowed options are defined by FileOptions in + // descriptor.proto, and any available extensions of that message. + const FileOptions& options() const; + + // Syntax of this file. + enum Syntax { + SYNTAX_UNKNOWN = 0, + SYNTAX_PROTO2 = 2, + SYNTAX_PROTO3 = 3, + }; + Syntax syntax() const; + static const char* SyntaxName(Syntax syntax); + + // Find a top-level message type by name. Returns nullptr if not found. + const Descriptor* FindMessageTypeByName(ConstStringParam name) const; + // Find a top-level enum type by name. Returns nullptr if not found. + const EnumDescriptor* FindEnumTypeByName(ConstStringParam name) const; + // Find an enum value defined in any top-level enum by name. Returns nullptr + // if not found. + const EnumValueDescriptor* FindEnumValueByName(ConstStringParam name) const; + // Find a service definition by name. Returns nullptr if not found. + const ServiceDescriptor* FindServiceByName(ConstStringParam name) const; + // Find a top-level extension definition by name. Returns nullptr if not + // found. + const FieldDescriptor* FindExtensionByName(ConstStringParam name) const; + // Similar to FindExtensionByName(), but searches by lowercased-name. See + // Descriptor::FindFieldByLowercaseName(). + const FieldDescriptor* FindExtensionByLowercaseName( + ConstStringParam name) const; + // Similar to FindExtensionByName(), but searches by camelcased-name. See + // Descriptor::FindFieldByCamelcaseName(). + const FieldDescriptor* FindExtensionByCamelcaseName( + ConstStringParam name) const; + + // See Descriptor::CopyTo(). + // Notes: + // - This method does NOT copy source code information since it is relatively + // large and rarely needed. See CopySourceCodeInfoTo() below. + void CopyTo(FileDescriptorProto* proto) const; + // Write the source code information of this FileDescriptor into the given + // FileDescriptorProto. See CopyTo() above. + void CopySourceCodeInfoTo(FileDescriptorProto* proto) const; + // Fill the json_name field of FieldDescriptorProto for all fields. Can only + // be called after CopyTo(). + void CopyJsonNameTo(FileDescriptorProto* proto) const; + + // See Descriptor::DebugString(). + std::string DebugString() const; + + // See Descriptor::DebugStringWithOptions(). + std::string DebugStringWithOptions(const DebugStringOptions& options) const; + + // Returns true if this is a placeholder for an unknown file. This will + // only be the case if this descriptor comes from a DescriptorPool + // with AllowUnknownDependencies() set. + bool is_placeholder() const; + + // Updates |*out_location| to the source location of the complete extent of + // this file declaration (namely, the empty path). + bool GetSourceLocation(SourceLocation* out_location) const; + + // Updates |*out_location| to the source location of the complete + // extent of the declaration or declaration-part denoted by |path|. + // Returns false and leaves |*out_location| unchanged iff location + // information was not available. (See SourceCodeInfo for + // description of path encoding.) + bool GetSourceLocation(const std::vector& path, + SourceLocation* out_location) const; + + private: + typedef FileOptions OptionsType; + + const std::string* name_; + const std::string* package_; + const DescriptorPool* pool_; + internal::once_flag* dependencies_once_; + static void DependenciesOnceInit(const FileDescriptor* to_init); + void InternalDependenciesOnceInit() const; + + // These are arranged to minimize padding on 64-bit. + int dependency_count_; + int public_dependency_count_; + int weak_dependency_count_; + int message_type_count_; + int enum_type_count_; + int service_count_; + int extension_count_; + Syntax syntax_; + bool is_placeholder_; + + // Indicates the FileDescriptor is completed building. Used to verify + // that type accessor functions that can possibly build a dependent file + // aren't called during the process of building the file. + bool finished_building_; + + mutable const FileDescriptor** dependencies_; + const std::string** dependencies_names_; + int* public_dependencies_; + int* weak_dependencies_; + Descriptor* message_types_; + EnumDescriptor* enum_types_; + ServiceDescriptor* services_; + FieldDescriptor* extensions_; + const FileOptions* options_; + + const FileDescriptorTables* tables_; + const SourceCodeInfo* source_code_info_; + + // IMPORTANT: If you add a new field, make sure to search for all instances + // of Allocate() and AllocateArray() in + // descriptor.cc and update them to initialize the field. + + FileDescriptor() {} + friend class DescriptorBuilder; + friend class DescriptorPool; + friend class Descriptor; + friend class FieldDescriptor; + friend class internal::LazyDescriptor; + friend class OneofDescriptor; + friend class EnumDescriptor; + friend class EnumValueDescriptor; + friend class MethodDescriptor; + friend class ServiceDescriptor; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FileDescriptor); +}; + + +// =================================================================== + +// Used to construct descriptors. +// +// Normally you won't want to build your own descriptors. Message classes +// constructed by the protocol compiler will provide them for you. However, +// if you are implementing Message on your own, or if you are writing a +// program which can operate on totally arbitrary types and needs to load +// them from some sort of database, you might need to. +// +// Since Descriptors are composed of a whole lot of cross-linked bits of +// data that would be a pain to put together manually, the +// DescriptorPool class is provided to make the process easier. It can +// take a FileDescriptorProto (defined in descriptor.proto), validate it, +// and convert it to a set of nicely cross-linked Descriptors. +// +// DescriptorPool also helps with memory management. Descriptors are +// composed of many objects containing static data and pointers to each +// other. In all likelihood, when it comes time to delete this data, +// you'll want to delete it all at once. In fact, it is not uncommon to +// have a whole pool of descriptors all cross-linked with each other which +// you wish to delete all at once. This class represents such a pool, and +// handles the memory management for you. +// +// You can also search for descriptors within a DescriptorPool by name, and +// extensions by number. +class PROTOBUF_EXPORT DescriptorPool { + public: + // Create a normal, empty DescriptorPool. + DescriptorPool(); + + // Constructs a DescriptorPool that, when it can't find something among the + // descriptors already in the pool, looks for it in the given + // DescriptorDatabase. + // Notes: + // - If a DescriptorPool is constructed this way, its BuildFile*() methods + // must not be called (they will assert-fail). The only way to populate + // the pool with descriptors is to call the Find*By*() methods. + // - The Find*By*() methods may block the calling thread if the + // DescriptorDatabase blocks. This in turn means that parsing messages + // may block if they need to look up extensions. + // - The Find*By*() methods will use mutexes for thread-safety, thus making + // them slower even when they don't have to fall back to the database. + // In fact, even the Find*By*() methods of descriptor objects owned by + // this pool will be slower, since they will have to obtain locks too. + // - An ErrorCollector may optionally be given to collect validation errors + // in files loaded from the database. If not given, errors will be printed + // to GOOGLE_LOG(ERROR). Remember that files are built on-demand, so this + // ErrorCollector may be called from any thread that calls one of the + // Find*By*() methods. + // - The DescriptorDatabase must not be mutated during the lifetime of + // the DescriptorPool. Even if the client takes care to avoid data races, + // changes to the content of the DescriptorDatabase may not be reflected + // in subsequent lookups in the DescriptorPool. + class ErrorCollector; + explicit DescriptorPool(DescriptorDatabase* fallback_database, + ErrorCollector* error_collector = nullptr); + + ~DescriptorPool(); + + // Get a pointer to the generated pool. Generated protocol message classes + // which are compiled into the binary will allocate their descriptors in + // this pool. Do not add your own descriptors to this pool. + static const DescriptorPool* generated_pool(); + + + // Find a FileDescriptor in the pool by file name. Returns nullptr if not + // found. + const FileDescriptor* FindFileByName(ConstStringParam name) const; + + // Find the FileDescriptor in the pool which defines the given symbol. + // If any of the Find*ByName() methods below would succeed, then this is + // equivalent to calling that method and calling the result's file() method. + // Otherwise this returns nullptr. + const FileDescriptor* FindFileContainingSymbol( + ConstStringParam symbol_name) const; + + // Looking up descriptors ------------------------------------------ + // These find descriptors by fully-qualified name. These will find both + // top-level descriptors and nested descriptors. They return nullptr if not + // found. + + const Descriptor* FindMessageTypeByName(ConstStringParam name) const; + const FieldDescriptor* FindFieldByName(ConstStringParam name) const; + const FieldDescriptor* FindExtensionByName(ConstStringParam name) const; + const OneofDescriptor* FindOneofByName(ConstStringParam name) const; + const EnumDescriptor* FindEnumTypeByName(ConstStringParam name) const; + const EnumValueDescriptor* FindEnumValueByName(ConstStringParam name) const; + const ServiceDescriptor* FindServiceByName(ConstStringParam name) const; + const MethodDescriptor* FindMethodByName(ConstStringParam name) const; + + // Finds an extension of the given type by number. The extendee must be + // a member of this DescriptorPool or one of its underlays. + const FieldDescriptor* FindExtensionByNumber(const Descriptor* extendee, + int number) const; + + // Finds an extension of the given type by its printable name. + // See comments above PrintableNameForExtension() for the definition of + // "printable name". The extendee must be a member of this DescriptorPool + // or one of its underlays. Returns nullptr if there is no known message + // extension with the given printable name. + const FieldDescriptor* FindExtensionByPrintableName( + const Descriptor* extendee, ConstStringParam printable_name) const; + + // Finds extensions of extendee. The extensions will be appended to + // out in an undefined order. Only extensions defined directly in + // this DescriptorPool or one of its underlays are guaranteed to be + // found: extensions defined in the fallback database might not be found + // depending on the database implementation. + void FindAllExtensions(const Descriptor* extendee, + std::vector* out) const; + + // Building descriptors -------------------------------------------- + + // When converting a FileDescriptorProto to a FileDescriptor, various + // errors might be detected in the input. The caller may handle these + // programmatically by implementing an ErrorCollector. + class PROTOBUF_EXPORT ErrorCollector { + public: + inline ErrorCollector() {} + virtual ~ErrorCollector(); + + // These constants specify what exact part of the construct is broken. + // This is useful e.g. for mapping the error back to an exact location + // in a .proto file. + enum ErrorLocation { + NAME, // the symbol name, or the package name for files + NUMBER, // field or extension range number + TYPE, // field type + EXTENDEE, // field extendee + DEFAULT_VALUE, // field default value + INPUT_TYPE, // method input type + OUTPUT_TYPE, // method output type + OPTION_NAME, // name in assignment + OPTION_VALUE, // value in option assignment + IMPORT, // import error + OTHER // some other problem + }; + + // Reports an error in the FileDescriptorProto. Use this function if the + // problem occurred should interrupt building the FileDescriptorProto. + virtual void AddError( + const std::string& filename, // File name in which the error occurred. + const std::string& element_name, // Full name of the erroneous element. + const Message* descriptor, // Descriptor of the erroneous element. + ErrorLocation location, // One of the location constants, above. + const std::string& message // Human-readable error message. + ) = 0; + + // Reports a warning in the FileDescriptorProto. Use this function if the + // problem occurred should NOT interrupt building the FileDescriptorProto. + virtual void AddWarning( + const std::string& /*filename*/, // File name in which the error + // occurred. + const std::string& /*element_name*/, // Full name of the erroneous + // element. + const Message* /*descriptor*/, // Descriptor of the erroneous element. + ErrorLocation /*location*/, // One of the location constants, above. + const std::string& /*message*/ // Human-readable error message. + ) {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ErrorCollector); + }; + + // Convert the FileDescriptorProto to real descriptors and place them in + // this DescriptorPool. All dependencies of the file must already be in + // the pool. Returns the resulting FileDescriptor, or nullptr if there were + // problems with the input (e.g. the message was invalid, or dependencies + // were missing). Details about the errors are written to GOOGLE_LOG(ERROR). + const FileDescriptor* BuildFile(const FileDescriptorProto& proto); + + // Same as BuildFile() except errors are sent to the given ErrorCollector. + const FileDescriptor* BuildFileCollectingErrors( + const FileDescriptorProto& proto, ErrorCollector* error_collector); + + // By default, it is an error if a FileDescriptorProto contains references + // to types or other files that are not found in the DescriptorPool (or its + // backing DescriptorDatabase, if any). If you call + // AllowUnknownDependencies(), however, then unknown types and files + // will be replaced by placeholder descriptors (which can be identified by + // the is_placeholder() method). This can allow you to + // perform some useful operations with a .proto file even if you do not + // have access to other .proto files on which it depends. However, some + // heuristics must be used to fill in the gaps in information, and these + // can lead to descriptors which are inaccurate. For example, the + // DescriptorPool may be forced to guess whether an unknown type is a message + // or an enum, as well as what package it resides in. Furthermore, + // placeholder types will not be discoverable via FindMessageTypeByName() + // and similar methods, which could confuse some descriptor-based algorithms. + // Generally, the results of this option should be handled with extreme care. + void AllowUnknownDependencies() { allow_unknown_ = true; } + + // By default, weak imports are allowed to be missing, in which case we will + // use a placeholder for the dependency and convert the field to be an Empty + // message field. If you call EnforceWeakDependencies(true), however, the + // DescriptorPool will report a import not found error. + void EnforceWeakDependencies(bool enforce) { enforce_weak_ = enforce; } + + // Internal stuff -------------------------------------------------- + // These methods MUST NOT be called from outside the proto2 library. + // These methods may contain hidden pitfalls and may be removed in a + // future library version. + + // Create a DescriptorPool which is overlaid on top of some other pool. + // If you search for a descriptor in the overlay and it is not found, the + // underlay will be searched as a backup. If the underlay has its own + // underlay, that will be searched next, and so on. This also means that + // files built in the overlay will be cross-linked with the underlay's + // descriptors if necessary. The underlay remains property of the caller; + // it must remain valid for the lifetime of the newly-constructed pool. + // + // Example: Say you want to parse a .proto file at runtime in order to use + // its type with a DynamicMessage. Say this .proto file has dependencies, + // but you know that all the dependencies will be things that are already + // compiled into the binary. For ease of use, you'd like to load the types + // right out of generated_pool() rather than have to parse redundant copies + // of all these .protos and runtime. But, you don't want to add the parsed + // types directly into generated_pool(): this is not allowed, and would be + // bad design anyway. So, instead, you could use generated_pool() as an + // underlay for a new DescriptorPool in which you add only the new file. + // + // WARNING: Use of underlays can lead to many subtle gotchas. Instead, + // try to formulate what you want to do in terms of DescriptorDatabases. + explicit DescriptorPool(const DescriptorPool* underlay); + + // Called by generated classes at init time to add their descriptors to + // generated_pool. Do NOT call this in your own code! filename must be a + // permanent string (e.g. a string literal). + static void InternalAddGeneratedFile(const void* encoded_file_descriptor, + int size); + + // Disallow [enforce_utf8 = false] in .proto files. + void DisallowEnforceUtf8() { disallow_enforce_utf8_ = true; } + + + // For internal use only: Gets a non-const pointer to the generated pool. + // This is called at static-initialization time only, so thread-safety is + // not a concern. If both an underlay and a fallback database are present, + // the underlay takes precedence. + static DescriptorPool* internal_generated_pool(); + + // For internal use only: Gets a non-const pointer to the generated + // descriptor database. + // Only used for testing. + static DescriptorDatabase* internal_generated_database(); + + // For internal use only: Changes the behavior of BuildFile() such that it + // allows the file to make reference to message types declared in other files + // which it did not officially declare as dependencies. + void InternalDontEnforceDependencies(); + + // For internal use only: Enables lazy building of dependencies of a file. + // Delay the building of dependencies of a file descriptor until absolutely + // necessary, like when message_type() is called on a field that is defined + // in that dependency's file. This will cause functional issues if a proto + // or one of it's dependencies has errors. Should only be enabled for the + // generated_pool_ (because no descriptor build errors are guaranteed by + // the compilation generation process), testing, or if a lack of descriptor + // build errors can be guaranteed for a pool. + void InternalSetLazilyBuildDependencies() { + lazily_build_dependencies_ = true; + // This needs to be set when lazily building dependencies, as it breaks + // dependency checking. + InternalDontEnforceDependencies(); + } + + // For internal use only. + void internal_set_underlay(const DescriptorPool* underlay) { + underlay_ = underlay; + } + + // For internal (unit test) use only: Returns true if a FileDescriptor has + // been constructed for the given file, false otherwise. Useful for testing + // lazy descriptor initialization behavior. + bool InternalIsFileLoaded(ConstStringParam filename) const; + + // Add a file to unused_import_track_files_. DescriptorBuilder will log + // warnings or errors for those files if there is any unused import. + void AddUnusedImportTrackFile(ConstStringParam file_name, + bool is_error = false); + void ClearUnusedImportTrackFiles(); + + private: + friend class Descriptor; + friend class internal::LazyDescriptor; + friend class FieldDescriptor; + friend class EnumDescriptor; + friend class ServiceDescriptor; + friend class MethodDescriptor; + friend class FileDescriptor; + friend class StreamDescriptor; + friend class DescriptorBuilder; + friend class FileDescriptorTables; + + // Return true if the given name is a sub-symbol of any non-package + // descriptor that already exists in the descriptor pool. (The full + // definition of such types is already known.) + bool IsSubSymbolOfBuiltType(StringPiece name) const; + + // Tries to find something in the fallback database and link in the + // corresponding proto file. Returns true if successful, in which case + // the caller should search for the thing again. These are declared + // const because they are called by (semantically) const methods. + bool TryFindFileInFallbackDatabase(StringPiece name) const; + bool TryFindSymbolInFallbackDatabase(StringPiece name) const; + bool TryFindExtensionInFallbackDatabase(const Descriptor* containing_type, + int field_number) const; + + // This internal find extension method only check with its table and underlay + // descriptor_pool's table. It does not check with fallback DB and no + // additional proto file will be build in this method. + const FieldDescriptor* InternalFindExtensionByNumberNoLock( + const Descriptor* extendee, int number) const; + + // Like BuildFile() but called internally when the file has been loaded from + // fallback_database_. Declared const because it is called by (semantically) + // const methods. + const FileDescriptor* BuildFileFromDatabase( + const FileDescriptorProto& proto) const; + + // Helper for when lazily_build_dependencies_ is set, can look up a symbol + // after the file's descriptor is built, and can build the file where that + // symbol is defined if necessary. Will create a placeholder if the type + // doesn't exist in the fallback database, or the file doesn't build + // successfully. + Symbol CrossLinkOnDemandHelper(StringPiece name, + bool expecting_enum) const; + + // Create a placeholder FileDescriptor of the specified name + FileDescriptor* NewPlaceholderFile(StringPiece name) const; + FileDescriptor* NewPlaceholderFileWithMutexHeld(StringPiece name) const; + + enum PlaceholderType { + PLACEHOLDER_MESSAGE, + PLACEHOLDER_ENUM, + PLACEHOLDER_EXTENDABLE_MESSAGE + }; + // Create a placeholder Descriptor of the specified name + Symbol NewPlaceholder(StringPiece name, + PlaceholderType placeholder_type) const; + Symbol NewPlaceholderWithMutexHeld(StringPiece name, + PlaceholderType placeholder_type) const; + + // If fallback_database_ is nullptr, this is nullptr. Otherwise, this is a + // mutex which must be locked while accessing tables_. + internal::WrappedMutex* mutex_; + + // See constructor. + DescriptorDatabase* fallback_database_; + ErrorCollector* default_error_collector_; + const DescriptorPool* underlay_; + + // This class contains a lot of hash maps with complicated types that + // we'd like to keep out of the header. + class Tables; + std::unique_ptr tables_; + + bool enforce_dependencies_; + bool lazily_build_dependencies_; + bool allow_unknown_; + bool enforce_weak_; + bool disallow_enforce_utf8_; + + // Set of files to track for unused imports. The bool value when true means + // unused imports are treated as errors (and as warnings when false). + std::map unused_import_track_files_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DescriptorPool); +}; + + +// inline methods ==================================================== + +// These macros makes this repetitive code more readable. +#define PROTOBUF_DEFINE_ACCESSOR(CLASS, FIELD, TYPE) \ + inline TYPE CLASS::FIELD() const { return FIELD##_; } + +// Strings fields are stored as pointers but returned as const references. +#define PROTOBUF_DEFINE_STRING_ACCESSOR(CLASS, FIELD) \ + inline const std::string& CLASS::FIELD() const { return *FIELD##_; } + +// Arrays take an index parameter, obviously. +#define PROTOBUF_DEFINE_ARRAY_ACCESSOR(CLASS, FIELD, TYPE) \ + inline TYPE CLASS::FIELD(int index) const { return FIELD##s_ + index; } + +#define PROTOBUF_DEFINE_OPTIONS_ACCESSOR(CLASS, TYPE) \ + inline const TYPE& CLASS::options() const { return *options_; } + +PROTOBUF_DEFINE_STRING_ACCESSOR(Descriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(Descriptor, full_name) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, file, const FileDescriptor*) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, containing_type, const Descriptor*) + +PROTOBUF_DEFINE_ACCESSOR(Descriptor, field_count, int) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, oneof_decl_count, int) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, real_oneof_decl_count, int) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, nested_type_count, int) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, enum_type_count, int) + +PROTOBUF_DEFINE_ARRAY_ACCESSOR(Descriptor, field, const FieldDescriptor*) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(Descriptor, oneof_decl, const OneofDescriptor*) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(Descriptor, nested_type, const Descriptor*) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(Descriptor, enum_type, const EnumDescriptor*) + +PROTOBUF_DEFINE_ACCESSOR(Descriptor, extension_range_count, int) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, extension_count, int) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(Descriptor, extension_range, + const Descriptor::ExtensionRange*) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(Descriptor, extension, const FieldDescriptor*) + +PROTOBUF_DEFINE_ACCESSOR(Descriptor, reserved_range_count, int) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(Descriptor, reserved_range, + const Descriptor::ReservedRange*) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, reserved_name_count, int) + +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(Descriptor, MessageOptions) +PROTOBUF_DEFINE_ACCESSOR(Descriptor, is_placeholder, bool) + +PROTOBUF_DEFINE_STRING_ACCESSOR(FieldDescriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(FieldDescriptor, full_name) +PROTOBUF_DEFINE_STRING_ACCESSOR(FieldDescriptor, json_name) +PROTOBUF_DEFINE_STRING_ACCESSOR(FieldDescriptor, lowercase_name) +PROTOBUF_DEFINE_STRING_ACCESSOR(FieldDescriptor, camelcase_name) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, file, const FileDescriptor*) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, number, int) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, is_extension, bool) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, label, FieldDescriptor::Label) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, containing_type, const Descriptor*) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, containing_oneof, + const OneofDescriptor*) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, index_in_oneof, int) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, extension_scope, const Descriptor*) +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(FieldDescriptor, FieldOptions) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, has_default_value, bool) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, has_json_name, bool) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, default_value_int32, int32) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, default_value_int64, int64) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, default_value_uint32, uint32) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, default_value_uint64, uint64) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, default_value_float, float) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, default_value_double, double) +PROTOBUF_DEFINE_ACCESSOR(FieldDescriptor, default_value_bool, bool) +PROTOBUF_DEFINE_STRING_ACCESSOR(FieldDescriptor, default_value_string) + +PROTOBUF_DEFINE_STRING_ACCESSOR(OneofDescriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(OneofDescriptor, full_name) +PROTOBUF_DEFINE_ACCESSOR(OneofDescriptor, containing_type, const Descriptor*) +PROTOBUF_DEFINE_ACCESSOR(OneofDescriptor, field_count, int) +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(OneofDescriptor, OneofOptions) + +PROTOBUF_DEFINE_STRING_ACCESSOR(EnumDescriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(EnumDescriptor, full_name) +PROTOBUF_DEFINE_ACCESSOR(EnumDescriptor, file, const FileDescriptor*) +PROTOBUF_DEFINE_ACCESSOR(EnumDescriptor, containing_type, const Descriptor*) +PROTOBUF_DEFINE_ACCESSOR(EnumDescriptor, value_count, int) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(EnumDescriptor, value, + const EnumValueDescriptor*) +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(EnumDescriptor, EnumOptions) +PROTOBUF_DEFINE_ACCESSOR(EnumDescriptor, is_placeholder, bool) +PROTOBUF_DEFINE_ACCESSOR(EnumDescriptor, reserved_range_count, int) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(EnumDescriptor, reserved_range, + const EnumDescriptor::ReservedRange*) +PROTOBUF_DEFINE_ACCESSOR(EnumDescriptor, reserved_name_count, int) + +PROTOBUF_DEFINE_STRING_ACCESSOR(EnumValueDescriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(EnumValueDescriptor, full_name) +PROTOBUF_DEFINE_ACCESSOR(EnumValueDescriptor, number, int) +PROTOBUF_DEFINE_ACCESSOR(EnumValueDescriptor, type, const EnumDescriptor*) +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(EnumValueDescriptor, EnumValueOptions) + +PROTOBUF_DEFINE_STRING_ACCESSOR(ServiceDescriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(ServiceDescriptor, full_name) +PROTOBUF_DEFINE_ACCESSOR(ServiceDescriptor, file, const FileDescriptor*) +PROTOBUF_DEFINE_ACCESSOR(ServiceDescriptor, method_count, int) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(ServiceDescriptor, method, + const MethodDescriptor*) +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(ServiceDescriptor, ServiceOptions) + +PROTOBUF_DEFINE_STRING_ACCESSOR(MethodDescriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(MethodDescriptor, full_name) +PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, service, const ServiceDescriptor*) +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(MethodDescriptor, MethodOptions) +PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, client_streaming, bool) +PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, server_streaming, bool) + +PROTOBUF_DEFINE_STRING_ACCESSOR(FileDescriptor, name) +PROTOBUF_DEFINE_STRING_ACCESSOR(FileDescriptor, package) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, pool, const DescriptorPool*) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, dependency_count, int) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, public_dependency_count, int) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, weak_dependency_count, int) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, message_type_count, int) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, enum_type_count, int) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, service_count, int) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, extension_count, int) +PROTOBUF_DEFINE_OPTIONS_ACCESSOR(FileDescriptor, FileOptions) +PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, is_placeholder, bool) + +PROTOBUF_DEFINE_ARRAY_ACCESSOR(FileDescriptor, message_type, const Descriptor*) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(FileDescriptor, enum_type, const EnumDescriptor*) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(FileDescriptor, service, + const ServiceDescriptor*) +PROTOBUF_DEFINE_ARRAY_ACCESSOR(FileDescriptor, extension, + const FieldDescriptor*) + +#undef PROTOBUF_DEFINE_ACCESSOR +#undef PROTOBUF_DEFINE_STRING_ACCESSOR +#undef PROTOBUF_DEFINE_ARRAY_ACCESSOR + +// A few accessors differ from the macros... + +inline Descriptor::WellKnownType Descriptor::well_known_type() const { + return static_cast(well_known_type_); +} + +inline bool Descriptor::IsExtensionNumber(int number) const { + return FindExtensionRangeContainingNumber(number) != nullptr; +} + +inline bool Descriptor::IsReservedNumber(int number) const { + return FindReservedRangeContainingNumber(number) != nullptr; +} + +inline bool Descriptor::IsReservedName(ConstStringParam name) const { + for (int i = 0; i < reserved_name_count(); i++) { + if (name == static_cast(reserved_name(i))) { + return true; + } + } + return false; +} + +// Can't use PROTOBUF_DEFINE_ARRAY_ACCESSOR because reserved_names_ is actually +// an array of pointers rather than the usual array of objects. +inline const std::string& Descriptor::reserved_name(int index) const { + return *reserved_names_[index]; +} + +inline bool EnumDescriptor::IsReservedNumber(int number) const { + return FindReservedRangeContainingNumber(number) != nullptr; +} + +inline bool EnumDescriptor::IsReservedName(ConstStringParam name) const { + for (int i = 0; i < reserved_name_count(); i++) { + if (name == static_cast(reserved_name(i))) { + return true; + } + } + return false; +} + +// Can't use PROTOBUF_DEFINE_ARRAY_ACCESSOR because reserved_names_ is actually +// an array of pointers rather than the usual array of objects. +inline const std::string& EnumDescriptor::reserved_name(int index) const { + return *reserved_names_[index]; +} + +inline FieldDescriptor::Type FieldDescriptor::type() const { + if (type_once_) { + internal::call_once(*type_once_, &FieldDescriptor::TypeOnceInit, this); + } + return type_; +} + +inline bool FieldDescriptor::is_required() const { + return label() == LABEL_REQUIRED; +} + +inline bool FieldDescriptor::is_optional() const { + return label() == LABEL_OPTIONAL; +} + +inline bool FieldDescriptor::is_repeated() const { + return label() == LABEL_REPEATED; +} + +inline bool FieldDescriptor::is_packable() const { + return is_repeated() && IsTypePackable(type()); +} + +inline bool FieldDescriptor::is_map() const { + return type() == TYPE_MESSAGE && is_map_message_type(); +} + +inline bool FieldDescriptor::has_optional_keyword() const { + return proto3_optional_ || + (file()->syntax() == FileDescriptor::SYNTAX_PROTO2 && is_optional() && + !containing_oneof()); +} + +inline const OneofDescriptor* FieldDescriptor::real_containing_oneof() const { + return containing_oneof_ && !containing_oneof_->is_synthetic() + ? containing_oneof_ + : nullptr; +} + +inline bool FieldDescriptor::has_presence() const { + if (is_repeated()) return false; + return cpp_type() == CPPTYPE_MESSAGE || containing_oneof() || + file()->syntax() == FileDescriptor::SYNTAX_PROTO2; +} + +// To save space, index() is computed by looking at the descriptor's position +// in the parent's array of children. +inline int FieldDescriptor::index() const { + if (!is_extension_) { + return static_cast(this - containing_type()->fields_); + } else if (extension_scope_ != nullptr) { + return static_cast(this - extension_scope_->extensions_); + } else { + return static_cast(this - file_->extensions_); + } +} + +inline int Descriptor::index() const { + if (containing_type_ == nullptr) { + return static_cast(this - file_->message_types_); + } else { + return static_cast(this - containing_type_->nested_types_); + } +} + +inline const FileDescriptor* OneofDescriptor::file() const { + return containing_type()->file(); +} + +inline int OneofDescriptor::index() const { + return static_cast(this - containing_type_->oneof_decls_); +} + +inline bool OneofDescriptor::is_synthetic() const { + return field_count() == 1 && field(0)->proto3_optional_; +} + +inline int EnumDescriptor::index() const { + if (containing_type_ == nullptr) { + return static_cast(this - file_->enum_types_); + } else { + return static_cast(this - containing_type_->enum_types_); + } +} + +inline const FileDescriptor* EnumValueDescriptor::file() const { + return type()->file(); +} + +inline int EnumValueDescriptor::index() const { + return static_cast(this - type_->values_); +} + +inline int ServiceDescriptor::index() const { + return static_cast(this - file_->services_); +} + +inline const FileDescriptor* MethodDescriptor::file() const { + return service()->file(); +} + +inline int MethodDescriptor::index() const { + return static_cast(this - service_->methods_); +} + +inline const char* FieldDescriptor::type_name() const { + return kTypeToName[type()]; +} + +inline FieldDescriptor::CppType FieldDescriptor::cpp_type() const { + return kTypeToCppTypeMap[type()]; +} + +inline const char* FieldDescriptor::cpp_type_name() const { + return kCppTypeToName[kTypeToCppTypeMap[type()]]; +} + +inline FieldDescriptor::CppType FieldDescriptor::TypeToCppType(Type type) { + return kTypeToCppTypeMap[type]; +} + +inline const char* FieldDescriptor::TypeName(Type type) { + return kTypeToName[type]; +} + +inline const char* FieldDescriptor::CppTypeName(CppType cpp_type) { + return kCppTypeToName[cpp_type]; +} + +inline bool FieldDescriptor::IsTypePackable(Type field_type) { + return (field_type != FieldDescriptor::TYPE_STRING && + field_type != FieldDescriptor::TYPE_GROUP && + field_type != FieldDescriptor::TYPE_MESSAGE && + field_type != FieldDescriptor::TYPE_BYTES); +} + +inline const FileDescriptor* FileDescriptor::public_dependency( + int index) const { + return dependency(public_dependencies_[index]); +} + +inline const FileDescriptor* FileDescriptor::weak_dependency(int index) const { + return dependency(weak_dependencies_[index]); +} + +inline FileDescriptor::Syntax FileDescriptor::syntax() const { return syntax_; } + +// Can't use PROTOBUF_DEFINE_ARRAY_ACCESSOR because fields_ is actually an array +// of pointers rather than the usual array of objects. +inline const FieldDescriptor* OneofDescriptor::field(int index) const { + return fields_[index]; +} + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_DESCRIPTOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..9eb2b2e55df09165755d5977cef55d75725ab2d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor.pb.h @@ -0,0 +1,12958 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/descriptor.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fdescriptor_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fdescriptor_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fdescriptor_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2fdescriptor_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[27] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fdescriptor_2eproto; +PROTOBUF_NAMESPACE_OPEN +class DescriptorProto; +class DescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern DescriptorProtoDefaultTypeInternal _DescriptorProto_default_instance_; +class DescriptorProto_ExtensionRange; +class DescriptorProto_ExtensionRangeDefaultTypeInternal; +PROTOBUF_EXPORT extern DescriptorProto_ExtensionRangeDefaultTypeInternal _DescriptorProto_ExtensionRange_default_instance_; +class DescriptorProto_ReservedRange; +class DescriptorProto_ReservedRangeDefaultTypeInternal; +PROTOBUF_EXPORT extern DescriptorProto_ReservedRangeDefaultTypeInternal _DescriptorProto_ReservedRange_default_instance_; +class EnumDescriptorProto; +class EnumDescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern EnumDescriptorProtoDefaultTypeInternal _EnumDescriptorProto_default_instance_; +class EnumDescriptorProto_EnumReservedRange; +class EnumDescriptorProto_EnumReservedRangeDefaultTypeInternal; +PROTOBUF_EXPORT extern EnumDescriptorProto_EnumReservedRangeDefaultTypeInternal _EnumDescriptorProto_EnumReservedRange_default_instance_; +class EnumOptions; +class EnumOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern EnumOptionsDefaultTypeInternal _EnumOptions_default_instance_; +class EnumValueDescriptorProto; +class EnumValueDescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern EnumValueDescriptorProtoDefaultTypeInternal _EnumValueDescriptorProto_default_instance_; +class EnumValueOptions; +class EnumValueOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern EnumValueOptionsDefaultTypeInternal _EnumValueOptions_default_instance_; +class ExtensionRangeOptions; +class ExtensionRangeOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern ExtensionRangeOptionsDefaultTypeInternal _ExtensionRangeOptions_default_instance_; +class FieldDescriptorProto; +class FieldDescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern FieldDescriptorProtoDefaultTypeInternal _FieldDescriptorProto_default_instance_; +class FieldOptions; +class FieldOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern FieldOptionsDefaultTypeInternal _FieldOptions_default_instance_; +class FileDescriptorProto; +class FileDescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern FileDescriptorProtoDefaultTypeInternal _FileDescriptorProto_default_instance_; +class FileDescriptorSet; +class FileDescriptorSetDefaultTypeInternal; +PROTOBUF_EXPORT extern FileDescriptorSetDefaultTypeInternal _FileDescriptorSet_default_instance_; +class FileOptions; +class FileOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern FileOptionsDefaultTypeInternal _FileOptions_default_instance_; +class GeneratedCodeInfo; +class GeneratedCodeInfoDefaultTypeInternal; +PROTOBUF_EXPORT extern GeneratedCodeInfoDefaultTypeInternal _GeneratedCodeInfo_default_instance_; +class GeneratedCodeInfo_Annotation; +class GeneratedCodeInfo_AnnotationDefaultTypeInternal; +PROTOBUF_EXPORT extern GeneratedCodeInfo_AnnotationDefaultTypeInternal _GeneratedCodeInfo_Annotation_default_instance_; +class MessageOptions; +class MessageOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern MessageOptionsDefaultTypeInternal _MessageOptions_default_instance_; +class MethodDescriptorProto; +class MethodDescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern MethodDescriptorProtoDefaultTypeInternal _MethodDescriptorProto_default_instance_; +class MethodOptions; +class MethodOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern MethodOptionsDefaultTypeInternal _MethodOptions_default_instance_; +class OneofDescriptorProto; +class OneofDescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern OneofDescriptorProtoDefaultTypeInternal _OneofDescriptorProto_default_instance_; +class OneofOptions; +class OneofOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern OneofOptionsDefaultTypeInternal _OneofOptions_default_instance_; +class ServiceDescriptorProto; +class ServiceDescriptorProtoDefaultTypeInternal; +PROTOBUF_EXPORT extern ServiceDescriptorProtoDefaultTypeInternal _ServiceDescriptorProto_default_instance_; +class ServiceOptions; +class ServiceOptionsDefaultTypeInternal; +PROTOBUF_EXPORT extern ServiceOptionsDefaultTypeInternal _ServiceOptions_default_instance_; +class SourceCodeInfo; +class SourceCodeInfoDefaultTypeInternal; +PROTOBUF_EXPORT extern SourceCodeInfoDefaultTypeInternal _SourceCodeInfo_default_instance_; +class SourceCodeInfo_Location; +class SourceCodeInfo_LocationDefaultTypeInternal; +PROTOBUF_EXPORT extern SourceCodeInfo_LocationDefaultTypeInternal _SourceCodeInfo_Location_default_instance_; +class UninterpretedOption; +class UninterpretedOptionDefaultTypeInternal; +PROTOBUF_EXPORT extern UninterpretedOptionDefaultTypeInternal _UninterpretedOption_default_instance_; +class UninterpretedOption_NamePart; +class UninterpretedOption_NamePartDefaultTypeInternal; +PROTOBUF_EXPORT extern UninterpretedOption_NamePartDefaultTypeInternal _UninterpretedOption_NamePart_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::DescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::EnumOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::EnumValueOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::FieldOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::FileDescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::FileDescriptorSet* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::FileOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::MessageOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::MethodDescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::MethodOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::OneofDescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::OneofOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::ServiceOptions* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::SourceCodeInfo* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::UninterpretedOption* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +enum FieldDescriptorProto_Type : int { + FieldDescriptorProto_Type_TYPE_DOUBLE = 1, + FieldDescriptorProto_Type_TYPE_FLOAT = 2, + FieldDescriptorProto_Type_TYPE_INT64 = 3, + FieldDescriptorProto_Type_TYPE_UINT64 = 4, + FieldDescriptorProto_Type_TYPE_INT32 = 5, + FieldDescriptorProto_Type_TYPE_FIXED64 = 6, + FieldDescriptorProto_Type_TYPE_FIXED32 = 7, + FieldDescriptorProto_Type_TYPE_BOOL = 8, + FieldDescriptorProto_Type_TYPE_STRING = 9, + FieldDescriptorProto_Type_TYPE_GROUP = 10, + FieldDescriptorProto_Type_TYPE_MESSAGE = 11, + FieldDescriptorProto_Type_TYPE_BYTES = 12, + FieldDescriptorProto_Type_TYPE_UINT32 = 13, + FieldDescriptorProto_Type_TYPE_ENUM = 14, + FieldDescriptorProto_Type_TYPE_SFIXED32 = 15, + FieldDescriptorProto_Type_TYPE_SFIXED64 = 16, + FieldDescriptorProto_Type_TYPE_SINT32 = 17, + FieldDescriptorProto_Type_TYPE_SINT64 = 18 +}; +PROTOBUF_EXPORT bool FieldDescriptorProto_Type_IsValid(int value); +constexpr FieldDescriptorProto_Type FieldDescriptorProto_Type_Type_MIN = FieldDescriptorProto_Type_TYPE_DOUBLE; +constexpr FieldDescriptorProto_Type FieldDescriptorProto_Type_Type_MAX = FieldDescriptorProto_Type_TYPE_SINT64; +constexpr int FieldDescriptorProto_Type_Type_ARRAYSIZE = FieldDescriptorProto_Type_Type_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* FieldDescriptorProto_Type_descriptor(); +template +inline const std::string& FieldDescriptorProto_Type_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function FieldDescriptorProto_Type_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + FieldDescriptorProto_Type_descriptor(), enum_t_value); +} +inline bool FieldDescriptorProto_Type_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, FieldDescriptorProto_Type* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + FieldDescriptorProto_Type_descriptor(), name, value); +} +enum FieldDescriptorProto_Label : int { + FieldDescriptorProto_Label_LABEL_OPTIONAL = 1, + FieldDescriptorProto_Label_LABEL_REQUIRED = 2, + FieldDescriptorProto_Label_LABEL_REPEATED = 3 +}; +PROTOBUF_EXPORT bool FieldDescriptorProto_Label_IsValid(int value); +constexpr FieldDescriptorProto_Label FieldDescriptorProto_Label_Label_MIN = FieldDescriptorProto_Label_LABEL_OPTIONAL; +constexpr FieldDescriptorProto_Label FieldDescriptorProto_Label_Label_MAX = FieldDescriptorProto_Label_LABEL_REPEATED; +constexpr int FieldDescriptorProto_Label_Label_ARRAYSIZE = FieldDescriptorProto_Label_Label_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* FieldDescriptorProto_Label_descriptor(); +template +inline const std::string& FieldDescriptorProto_Label_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function FieldDescriptorProto_Label_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + FieldDescriptorProto_Label_descriptor(), enum_t_value); +} +inline bool FieldDescriptorProto_Label_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, FieldDescriptorProto_Label* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + FieldDescriptorProto_Label_descriptor(), name, value); +} +enum FileOptions_OptimizeMode : int { + FileOptions_OptimizeMode_SPEED = 1, + FileOptions_OptimizeMode_CODE_SIZE = 2, + FileOptions_OptimizeMode_LITE_RUNTIME = 3 +}; +PROTOBUF_EXPORT bool FileOptions_OptimizeMode_IsValid(int value); +constexpr FileOptions_OptimizeMode FileOptions_OptimizeMode_OptimizeMode_MIN = FileOptions_OptimizeMode_SPEED; +constexpr FileOptions_OptimizeMode FileOptions_OptimizeMode_OptimizeMode_MAX = FileOptions_OptimizeMode_LITE_RUNTIME; +constexpr int FileOptions_OptimizeMode_OptimizeMode_ARRAYSIZE = FileOptions_OptimizeMode_OptimizeMode_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* FileOptions_OptimizeMode_descriptor(); +template +inline const std::string& FileOptions_OptimizeMode_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function FileOptions_OptimizeMode_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + FileOptions_OptimizeMode_descriptor(), enum_t_value); +} +inline bool FileOptions_OptimizeMode_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, FileOptions_OptimizeMode* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + FileOptions_OptimizeMode_descriptor(), name, value); +} +enum FieldOptions_CType : int { + FieldOptions_CType_STRING = 0, + FieldOptions_CType_CORD = 1, + FieldOptions_CType_STRING_PIECE = 2 +}; +PROTOBUF_EXPORT bool FieldOptions_CType_IsValid(int value); +constexpr FieldOptions_CType FieldOptions_CType_CType_MIN = FieldOptions_CType_STRING; +constexpr FieldOptions_CType FieldOptions_CType_CType_MAX = FieldOptions_CType_STRING_PIECE; +constexpr int FieldOptions_CType_CType_ARRAYSIZE = FieldOptions_CType_CType_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* FieldOptions_CType_descriptor(); +template +inline const std::string& FieldOptions_CType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function FieldOptions_CType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + FieldOptions_CType_descriptor(), enum_t_value); +} +inline bool FieldOptions_CType_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, FieldOptions_CType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + FieldOptions_CType_descriptor(), name, value); +} +enum FieldOptions_JSType : int { + FieldOptions_JSType_JS_NORMAL = 0, + FieldOptions_JSType_JS_STRING = 1, + FieldOptions_JSType_JS_NUMBER = 2 +}; +PROTOBUF_EXPORT bool FieldOptions_JSType_IsValid(int value); +constexpr FieldOptions_JSType FieldOptions_JSType_JSType_MIN = FieldOptions_JSType_JS_NORMAL; +constexpr FieldOptions_JSType FieldOptions_JSType_JSType_MAX = FieldOptions_JSType_JS_NUMBER; +constexpr int FieldOptions_JSType_JSType_ARRAYSIZE = FieldOptions_JSType_JSType_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* FieldOptions_JSType_descriptor(); +template +inline const std::string& FieldOptions_JSType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function FieldOptions_JSType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + FieldOptions_JSType_descriptor(), enum_t_value); +} +inline bool FieldOptions_JSType_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, FieldOptions_JSType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + FieldOptions_JSType_descriptor(), name, value); +} +enum MethodOptions_IdempotencyLevel : int { + MethodOptions_IdempotencyLevel_IDEMPOTENCY_UNKNOWN = 0, + MethodOptions_IdempotencyLevel_NO_SIDE_EFFECTS = 1, + MethodOptions_IdempotencyLevel_IDEMPOTENT = 2 +}; +PROTOBUF_EXPORT bool MethodOptions_IdempotencyLevel_IsValid(int value); +constexpr MethodOptions_IdempotencyLevel MethodOptions_IdempotencyLevel_IdempotencyLevel_MIN = MethodOptions_IdempotencyLevel_IDEMPOTENCY_UNKNOWN; +constexpr MethodOptions_IdempotencyLevel MethodOptions_IdempotencyLevel_IdempotencyLevel_MAX = MethodOptions_IdempotencyLevel_IDEMPOTENT; +constexpr int MethodOptions_IdempotencyLevel_IdempotencyLevel_ARRAYSIZE = MethodOptions_IdempotencyLevel_IdempotencyLevel_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* MethodOptions_IdempotencyLevel_descriptor(); +template +inline const std::string& MethodOptions_IdempotencyLevel_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function MethodOptions_IdempotencyLevel_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + MethodOptions_IdempotencyLevel_descriptor(), enum_t_value); +} +inline bool MethodOptions_IdempotencyLevel_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, MethodOptions_IdempotencyLevel* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + MethodOptions_IdempotencyLevel_descriptor(), name, value); +} +// =================================================================== + +class PROTOBUF_EXPORT FileDescriptorSet PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.FileDescriptorSet) */ { + public: + inline FileDescriptorSet() : FileDescriptorSet(nullptr) {} + virtual ~FileDescriptorSet(); + + FileDescriptorSet(const FileDescriptorSet& from); + FileDescriptorSet(FileDescriptorSet&& from) noexcept + : FileDescriptorSet() { + *this = ::std::move(from); + } + + inline FileDescriptorSet& operator=(const FileDescriptorSet& from) { + CopyFrom(from); + return *this; + } + inline FileDescriptorSet& operator=(FileDescriptorSet&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const FileDescriptorSet& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const FileDescriptorSet* internal_default_instance() { + return reinterpret_cast( + &_FileDescriptorSet_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(FileDescriptorSet& a, FileDescriptorSet& b) { + a.Swap(&b); + } + inline void Swap(FileDescriptorSet* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(FileDescriptorSet* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline FileDescriptorSet* New() const final { + return CreateMaybeMessage(nullptr); + } + + FileDescriptorSet* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const FileDescriptorSet& from); + void MergeFrom(const FileDescriptorSet& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(FileDescriptorSet* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.FileDescriptorSet"; + } + protected: + explicit FileDescriptorSet(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kFileFieldNumber = 1, + }; + // repeated .google.protobuf.FileDescriptorProto file = 1; + int file_size() const; + private: + int _internal_file_size() const; + public: + void clear_file(); + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* mutable_file(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >* + mutable_file(); + private: + const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& _internal_file(int index) const; + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* _internal_add_file(); + public: + const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& file(int index) const; + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* add_file(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >& + file() const; + + // @@protoc_insertion_point(class_scope:google.protobuf.FileDescriptorSet) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto > file_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT FileDescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.FileDescriptorProto) */ { + public: + inline FileDescriptorProto() : FileDescriptorProto(nullptr) {} + virtual ~FileDescriptorProto(); + + FileDescriptorProto(const FileDescriptorProto& from); + FileDescriptorProto(FileDescriptorProto&& from) noexcept + : FileDescriptorProto() { + *this = ::std::move(from); + } + + inline FileDescriptorProto& operator=(const FileDescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline FileDescriptorProto& operator=(FileDescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const FileDescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const FileDescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_FileDescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(FileDescriptorProto& a, FileDescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(FileDescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(FileDescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline FileDescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + FileDescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const FileDescriptorProto& from); + void MergeFrom(const FileDescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(FileDescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.FileDescriptorProto"; + } + protected: + explicit FileDescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDependencyFieldNumber = 3, + kMessageTypeFieldNumber = 4, + kEnumTypeFieldNumber = 5, + kServiceFieldNumber = 6, + kExtensionFieldNumber = 7, + kPublicDependencyFieldNumber = 10, + kWeakDependencyFieldNumber = 11, + kNameFieldNumber = 1, + kPackageFieldNumber = 2, + kSyntaxFieldNumber = 12, + kOptionsFieldNumber = 8, + kSourceCodeInfoFieldNumber = 9, + }; + // repeated string dependency = 3; + int dependency_size() const; + private: + int _internal_dependency_size() const; + public: + void clear_dependency(); + const std::string& dependency(int index) const; + std::string* mutable_dependency(int index); + void set_dependency(int index, const std::string& value); + void set_dependency(int index, std::string&& value); + void set_dependency(int index, const char* value); + void set_dependency(int index, const char* value, size_t size); + std::string* add_dependency(); + void add_dependency(const std::string& value); + void add_dependency(std::string&& value); + void add_dependency(const char* value); + void add_dependency(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& dependency() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_dependency(); + private: + const std::string& _internal_dependency(int index) const; + std::string* _internal_add_dependency(); + public: + + // repeated .google.protobuf.DescriptorProto message_type = 4; + int message_type_size() const; + private: + int _internal_message_type_size() const; + public: + void clear_message_type(); + PROTOBUF_NAMESPACE_ID::DescriptorProto* mutable_message_type(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >* + mutable_message_type(); + private: + const PROTOBUF_NAMESPACE_ID::DescriptorProto& _internal_message_type(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto* _internal_add_message_type(); + public: + const PROTOBUF_NAMESPACE_ID::DescriptorProto& message_type(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto* add_message_type(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >& + message_type() const; + + // repeated .google.protobuf.EnumDescriptorProto enum_type = 5; + int enum_type_size() const; + private: + int _internal_enum_type_size() const; + public: + void clear_enum_type(); + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* mutable_enum_type(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >* + mutable_enum_type(); + private: + const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& _internal_enum_type(int index) const; + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* _internal_add_enum_type(); + public: + const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& enum_type(int index) const; + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* add_enum_type(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >& + enum_type() const; + + // repeated .google.protobuf.ServiceDescriptorProto service = 6; + int service_size() const; + private: + int _internal_service_size() const; + public: + void clear_service(); + PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto* mutable_service(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto >* + mutable_service(); + private: + const PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto& _internal_service(int index) const; + PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto* _internal_add_service(); + public: + const PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto& service(int index) const; + PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto* add_service(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto >& + service() const; + + // repeated .google.protobuf.FieldDescriptorProto extension = 7; + int extension_size() const; + private: + int _internal_extension_size() const; + public: + void clear_extension(); + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* mutable_extension(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >* + mutable_extension(); + private: + const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& _internal_extension(int index) const; + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* _internal_add_extension(); + public: + const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& extension(int index) const; + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* add_extension(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >& + extension() const; + + // repeated int32 public_dependency = 10; + int public_dependency_size() const; + private: + int _internal_public_dependency_size() const; + public: + void clear_public_dependency(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_public_dependency(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_public_dependency() const; + void _internal_add_public_dependency(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_public_dependency(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 public_dependency(int index) const; + void set_public_dependency(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_public_dependency(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + public_dependency() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_public_dependency(); + + // repeated int32 weak_dependency = 11; + int weak_dependency_size() const; + private: + int _internal_weak_dependency_size() const; + public: + void clear_weak_dependency(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_weak_dependency(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_weak_dependency() const; + void _internal_add_weak_dependency(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_weak_dependency(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 weak_dependency(int index) const; + void set_weak_dependency(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_weak_dependency(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + weak_dependency() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_weak_dependency(); + + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string package = 2; + bool has_package() const; + private: + bool _internal_has_package() const; + public: + void clear_package(); + const std::string& package() const; + void set_package(const std::string& value); + void set_package(std::string&& value); + void set_package(const char* value); + void set_package(const char* value, size_t size); + std::string* mutable_package(); + std::string* release_package(); + void set_allocated_package(std::string* package); + private: + const std::string& _internal_package() const; + void _internal_set_package(const std::string& value); + std::string* _internal_mutable_package(); + public: + + // optional string syntax = 12; + bool has_syntax() const; + private: + bool _internal_has_syntax() const; + public: + void clear_syntax(); + const std::string& syntax() const; + void set_syntax(const std::string& value); + void set_syntax(std::string&& value); + void set_syntax(const char* value); + void set_syntax(const char* value, size_t size); + std::string* mutable_syntax(); + std::string* release_syntax(); + void set_allocated_syntax(std::string* syntax); + private: + const std::string& _internal_syntax() const; + void _internal_set_syntax(const std::string& value); + std::string* _internal_mutable_syntax(); + public: + + // optional .google.protobuf.FileOptions options = 8; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::FileOptions& options() const; + PROTOBUF_NAMESPACE_ID::FileOptions* release_options(); + PROTOBUF_NAMESPACE_ID::FileOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::FileOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::FileOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::FileOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::FileOptions* options); + PROTOBUF_NAMESPACE_ID::FileOptions* unsafe_arena_release_options(); + + // optional .google.protobuf.SourceCodeInfo source_code_info = 9; + bool has_source_code_info() const; + private: + bool _internal_has_source_code_info() const; + public: + void clear_source_code_info(); + const PROTOBUF_NAMESPACE_ID::SourceCodeInfo& source_code_info() const; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* release_source_code_info(); + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* mutable_source_code_info(); + void set_allocated_source_code_info(PROTOBUF_NAMESPACE_ID::SourceCodeInfo* source_code_info); + private: + const PROTOBUF_NAMESPACE_ID::SourceCodeInfo& _internal_source_code_info() const; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* _internal_mutable_source_code_info(); + public: + void unsafe_arena_set_allocated_source_code_info( + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* source_code_info); + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* unsafe_arena_release_source_code_info(); + + // @@protoc_insertion_point(class_scope:google.protobuf.FileDescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField dependency_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto > message_type_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto > enum_type_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto > service_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto > extension_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > public_dependency_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > weak_dependency_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr package_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr syntax_; + PROTOBUF_NAMESPACE_ID::FileOptions* options_; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* source_code_info_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT DescriptorProto_ExtensionRange PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.DescriptorProto.ExtensionRange) */ { + public: + inline DescriptorProto_ExtensionRange() : DescriptorProto_ExtensionRange(nullptr) {} + virtual ~DescriptorProto_ExtensionRange(); + + DescriptorProto_ExtensionRange(const DescriptorProto_ExtensionRange& from); + DescriptorProto_ExtensionRange(DescriptorProto_ExtensionRange&& from) noexcept + : DescriptorProto_ExtensionRange() { + *this = ::std::move(from); + } + + inline DescriptorProto_ExtensionRange& operator=(const DescriptorProto_ExtensionRange& from) { + CopyFrom(from); + return *this; + } + inline DescriptorProto_ExtensionRange& operator=(DescriptorProto_ExtensionRange&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const DescriptorProto_ExtensionRange& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const DescriptorProto_ExtensionRange* internal_default_instance() { + return reinterpret_cast( + &_DescriptorProto_ExtensionRange_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(DescriptorProto_ExtensionRange& a, DescriptorProto_ExtensionRange& b) { + a.Swap(&b); + } + inline void Swap(DescriptorProto_ExtensionRange* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(DescriptorProto_ExtensionRange* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline DescriptorProto_ExtensionRange* New() const final { + return CreateMaybeMessage(nullptr); + } + + DescriptorProto_ExtensionRange* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const DescriptorProto_ExtensionRange& from); + void MergeFrom(const DescriptorProto_ExtensionRange& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(DescriptorProto_ExtensionRange* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.DescriptorProto.ExtensionRange"; + } + protected: + explicit DescriptorProto_ExtensionRange(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kOptionsFieldNumber = 3, + kStartFieldNumber = 1, + kEndFieldNumber = 2, + }; + // optional .google.protobuf.ExtensionRangeOptions options = 3; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions& options() const; + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* release_options(); + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* options); + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* unsafe_arena_release_options(); + + // optional int32 start = 1; + bool has_start() const; + private: + bool _internal_has_start() const; + public: + void clear_start(); + ::PROTOBUF_NAMESPACE_ID::int32 start() const; + void set_start(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_start() const; + void _internal_set_start(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 end = 2; + bool has_end() const; + private: + bool _internal_has_end() const; + public: + void clear_end(); + ::PROTOBUF_NAMESPACE_ID::int32 end() const; + void set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_end() const; + void _internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.DescriptorProto.ExtensionRange) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* options_; + ::PROTOBUF_NAMESPACE_ID::int32 start_; + ::PROTOBUF_NAMESPACE_ID::int32 end_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT DescriptorProto_ReservedRange PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.DescriptorProto.ReservedRange) */ { + public: + inline DescriptorProto_ReservedRange() : DescriptorProto_ReservedRange(nullptr) {} + virtual ~DescriptorProto_ReservedRange(); + + DescriptorProto_ReservedRange(const DescriptorProto_ReservedRange& from); + DescriptorProto_ReservedRange(DescriptorProto_ReservedRange&& from) noexcept + : DescriptorProto_ReservedRange() { + *this = ::std::move(from); + } + + inline DescriptorProto_ReservedRange& operator=(const DescriptorProto_ReservedRange& from) { + CopyFrom(from); + return *this; + } + inline DescriptorProto_ReservedRange& operator=(DescriptorProto_ReservedRange&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const DescriptorProto_ReservedRange& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const DescriptorProto_ReservedRange* internal_default_instance() { + return reinterpret_cast( + &_DescriptorProto_ReservedRange_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(DescriptorProto_ReservedRange& a, DescriptorProto_ReservedRange& b) { + a.Swap(&b); + } + inline void Swap(DescriptorProto_ReservedRange* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(DescriptorProto_ReservedRange* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline DescriptorProto_ReservedRange* New() const final { + return CreateMaybeMessage(nullptr); + } + + DescriptorProto_ReservedRange* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const DescriptorProto_ReservedRange& from); + void MergeFrom(const DescriptorProto_ReservedRange& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(DescriptorProto_ReservedRange* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.DescriptorProto.ReservedRange"; + } + protected: + explicit DescriptorProto_ReservedRange(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kStartFieldNumber = 1, + kEndFieldNumber = 2, + }; + // optional int32 start = 1; + bool has_start() const; + private: + bool _internal_has_start() const; + public: + void clear_start(); + ::PROTOBUF_NAMESPACE_ID::int32 start() const; + void set_start(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_start() const; + void _internal_set_start(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 end = 2; + bool has_end() const; + private: + bool _internal_has_end() const; + public: + void clear_end(); + ::PROTOBUF_NAMESPACE_ID::int32 end() const; + void set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_end() const; + void _internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.DescriptorProto.ReservedRange) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::int32 start_; + ::PROTOBUF_NAMESPACE_ID::int32 end_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT DescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.DescriptorProto) */ { + public: + inline DescriptorProto() : DescriptorProto(nullptr) {} + virtual ~DescriptorProto(); + + DescriptorProto(const DescriptorProto& from); + DescriptorProto(DescriptorProto&& from) noexcept + : DescriptorProto() { + *this = ::std::move(from); + } + + inline DescriptorProto& operator=(const DescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline DescriptorProto& operator=(DescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const DescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const DescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_DescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(DescriptorProto& a, DescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(DescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(DescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline DescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + DescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const DescriptorProto& from); + void MergeFrom(const DescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(DescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.DescriptorProto"; + } + protected: + explicit DescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef DescriptorProto_ExtensionRange ExtensionRange; + typedef DescriptorProto_ReservedRange ReservedRange; + + // accessors ------------------------------------------------------- + + enum : int { + kFieldFieldNumber = 2, + kNestedTypeFieldNumber = 3, + kEnumTypeFieldNumber = 4, + kExtensionRangeFieldNumber = 5, + kExtensionFieldNumber = 6, + kOneofDeclFieldNumber = 8, + kReservedRangeFieldNumber = 9, + kReservedNameFieldNumber = 10, + kNameFieldNumber = 1, + kOptionsFieldNumber = 7, + }; + // repeated .google.protobuf.FieldDescriptorProto field = 2; + int field_size() const; + private: + int _internal_field_size() const; + public: + void clear_field(); + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* mutable_field(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >* + mutable_field(); + private: + const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& _internal_field(int index) const; + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* _internal_add_field(); + public: + const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& field(int index) const; + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* add_field(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >& + field() const; + + // repeated .google.protobuf.DescriptorProto nested_type = 3; + int nested_type_size() const; + private: + int _internal_nested_type_size() const; + public: + void clear_nested_type(); + PROTOBUF_NAMESPACE_ID::DescriptorProto* mutable_nested_type(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >* + mutable_nested_type(); + private: + const PROTOBUF_NAMESPACE_ID::DescriptorProto& _internal_nested_type(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto* _internal_add_nested_type(); + public: + const PROTOBUF_NAMESPACE_ID::DescriptorProto& nested_type(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto* add_nested_type(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >& + nested_type() const; + + // repeated .google.protobuf.EnumDescriptorProto enum_type = 4; + int enum_type_size() const; + private: + int _internal_enum_type_size() const; + public: + void clear_enum_type(); + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* mutable_enum_type(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >* + mutable_enum_type(); + private: + const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& _internal_enum_type(int index) const; + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* _internal_add_enum_type(); + public: + const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& enum_type(int index) const; + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* add_enum_type(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >& + enum_type() const; + + // repeated .google.protobuf.DescriptorProto.ExtensionRange extension_range = 5; + int extension_range_size() const; + private: + int _internal_extension_range_size() const; + public: + void clear_extension_range(); + PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange* mutable_extension_range(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange >* + mutable_extension_range(); + private: + const PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange& _internal_extension_range(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange* _internal_add_extension_range(); + public: + const PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange& extension_range(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange* add_extension_range(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange >& + extension_range() const; + + // repeated .google.protobuf.FieldDescriptorProto extension = 6; + int extension_size() const; + private: + int _internal_extension_size() const; + public: + void clear_extension(); + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* mutable_extension(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >* + mutable_extension(); + private: + const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& _internal_extension(int index) const; + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* _internal_add_extension(); + public: + const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& extension(int index) const; + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* add_extension(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >& + extension() const; + + // repeated .google.protobuf.OneofDescriptorProto oneof_decl = 8; + int oneof_decl_size() const; + private: + int _internal_oneof_decl_size() const; + public: + void clear_oneof_decl(); + PROTOBUF_NAMESPACE_ID::OneofDescriptorProto* mutable_oneof_decl(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::OneofDescriptorProto >* + mutable_oneof_decl(); + private: + const PROTOBUF_NAMESPACE_ID::OneofDescriptorProto& _internal_oneof_decl(int index) const; + PROTOBUF_NAMESPACE_ID::OneofDescriptorProto* _internal_add_oneof_decl(); + public: + const PROTOBUF_NAMESPACE_ID::OneofDescriptorProto& oneof_decl(int index) const; + PROTOBUF_NAMESPACE_ID::OneofDescriptorProto* add_oneof_decl(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::OneofDescriptorProto >& + oneof_decl() const; + + // repeated .google.protobuf.DescriptorProto.ReservedRange reserved_range = 9; + int reserved_range_size() const; + private: + int _internal_reserved_range_size() const; + public: + void clear_reserved_range(); + PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange* mutable_reserved_range(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange >* + mutable_reserved_range(); + private: + const PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange& _internal_reserved_range(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange* _internal_add_reserved_range(); + public: + const PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange& reserved_range(int index) const; + PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange* add_reserved_range(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange >& + reserved_range() const; + + // repeated string reserved_name = 10; + int reserved_name_size() const; + private: + int _internal_reserved_name_size() const; + public: + void clear_reserved_name(); + const std::string& reserved_name(int index) const; + std::string* mutable_reserved_name(int index); + void set_reserved_name(int index, const std::string& value); + void set_reserved_name(int index, std::string&& value); + void set_reserved_name(int index, const char* value); + void set_reserved_name(int index, const char* value, size_t size); + std::string* add_reserved_name(); + void add_reserved_name(const std::string& value); + void add_reserved_name(std::string&& value); + void add_reserved_name(const char* value); + void add_reserved_name(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& reserved_name() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_reserved_name(); + private: + const std::string& _internal_reserved_name(int index) const; + std::string* _internal_add_reserved_name(); + public: + + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .google.protobuf.MessageOptions options = 7; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::MessageOptions& options() const; + PROTOBUF_NAMESPACE_ID::MessageOptions* release_options(); + PROTOBUF_NAMESPACE_ID::MessageOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::MessageOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::MessageOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::MessageOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::MessageOptions* options); + PROTOBUF_NAMESPACE_ID::MessageOptions* unsafe_arena_release_options(); + + // @@protoc_insertion_point(class_scope:google.protobuf.DescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto > field_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto > nested_type_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto > enum_type_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange > extension_range_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto > extension_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::OneofDescriptorProto > oneof_decl_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange > reserved_range_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField reserved_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + PROTOBUF_NAMESPACE_ID::MessageOptions* options_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT ExtensionRangeOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.ExtensionRangeOptions) */ { + public: + inline ExtensionRangeOptions() : ExtensionRangeOptions(nullptr) {} + virtual ~ExtensionRangeOptions(); + + ExtensionRangeOptions(const ExtensionRangeOptions& from); + ExtensionRangeOptions(ExtensionRangeOptions&& from) noexcept + : ExtensionRangeOptions() { + *this = ::std::move(from); + } + + inline ExtensionRangeOptions& operator=(const ExtensionRangeOptions& from) { + CopyFrom(from); + return *this; + } + inline ExtensionRangeOptions& operator=(ExtensionRangeOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ExtensionRangeOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ExtensionRangeOptions* internal_default_instance() { + return reinterpret_cast( + &_ExtensionRangeOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + friend void swap(ExtensionRangeOptions& a, ExtensionRangeOptions& b) { + a.Swap(&b); + } + inline void Swap(ExtensionRangeOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ExtensionRangeOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ExtensionRangeOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + ExtensionRangeOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ExtensionRangeOptions& from); + void MergeFrom(const ExtensionRangeOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ExtensionRangeOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.ExtensionRangeOptions"; + } + protected: + explicit ExtensionRangeOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(ExtensionRangeOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.ExtensionRangeOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT FieldDescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.FieldDescriptorProto) */ { + public: + inline FieldDescriptorProto() : FieldDescriptorProto(nullptr) {} + virtual ~FieldDescriptorProto(); + + FieldDescriptorProto(const FieldDescriptorProto& from); + FieldDescriptorProto(FieldDescriptorProto&& from) noexcept + : FieldDescriptorProto() { + *this = ::std::move(from); + } + + inline FieldDescriptorProto& operator=(const FieldDescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline FieldDescriptorProto& operator=(FieldDescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const FieldDescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const FieldDescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_FieldDescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 6; + + friend void swap(FieldDescriptorProto& a, FieldDescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(FieldDescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(FieldDescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline FieldDescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + FieldDescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const FieldDescriptorProto& from); + void MergeFrom(const FieldDescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(FieldDescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.FieldDescriptorProto"; + } + protected: + explicit FieldDescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef FieldDescriptorProto_Type Type; + static constexpr Type TYPE_DOUBLE = + FieldDescriptorProto_Type_TYPE_DOUBLE; + static constexpr Type TYPE_FLOAT = + FieldDescriptorProto_Type_TYPE_FLOAT; + static constexpr Type TYPE_INT64 = + FieldDescriptorProto_Type_TYPE_INT64; + static constexpr Type TYPE_UINT64 = + FieldDescriptorProto_Type_TYPE_UINT64; + static constexpr Type TYPE_INT32 = + FieldDescriptorProto_Type_TYPE_INT32; + static constexpr Type TYPE_FIXED64 = + FieldDescriptorProto_Type_TYPE_FIXED64; + static constexpr Type TYPE_FIXED32 = + FieldDescriptorProto_Type_TYPE_FIXED32; + static constexpr Type TYPE_BOOL = + FieldDescriptorProto_Type_TYPE_BOOL; + static constexpr Type TYPE_STRING = + FieldDescriptorProto_Type_TYPE_STRING; + static constexpr Type TYPE_GROUP = + FieldDescriptorProto_Type_TYPE_GROUP; + static constexpr Type TYPE_MESSAGE = + FieldDescriptorProto_Type_TYPE_MESSAGE; + static constexpr Type TYPE_BYTES = + FieldDescriptorProto_Type_TYPE_BYTES; + static constexpr Type TYPE_UINT32 = + FieldDescriptorProto_Type_TYPE_UINT32; + static constexpr Type TYPE_ENUM = + FieldDescriptorProto_Type_TYPE_ENUM; + static constexpr Type TYPE_SFIXED32 = + FieldDescriptorProto_Type_TYPE_SFIXED32; + static constexpr Type TYPE_SFIXED64 = + FieldDescriptorProto_Type_TYPE_SFIXED64; + static constexpr Type TYPE_SINT32 = + FieldDescriptorProto_Type_TYPE_SINT32; + static constexpr Type TYPE_SINT64 = + FieldDescriptorProto_Type_TYPE_SINT64; + static inline bool Type_IsValid(int value) { + return FieldDescriptorProto_Type_IsValid(value); + } + static constexpr Type Type_MIN = + FieldDescriptorProto_Type_Type_MIN; + static constexpr Type Type_MAX = + FieldDescriptorProto_Type_Type_MAX; + static constexpr int Type_ARRAYSIZE = + FieldDescriptorProto_Type_Type_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + Type_descriptor() { + return FieldDescriptorProto_Type_descriptor(); + } + template + static inline const std::string& Type_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Type_Name."); + return FieldDescriptorProto_Type_Name(enum_t_value); + } + static inline bool Type_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + Type* value) { + return FieldDescriptorProto_Type_Parse(name, value); + } + + typedef FieldDescriptorProto_Label Label; + static constexpr Label LABEL_OPTIONAL = + FieldDescriptorProto_Label_LABEL_OPTIONAL; + static constexpr Label LABEL_REQUIRED = + FieldDescriptorProto_Label_LABEL_REQUIRED; + static constexpr Label LABEL_REPEATED = + FieldDescriptorProto_Label_LABEL_REPEATED; + static inline bool Label_IsValid(int value) { + return FieldDescriptorProto_Label_IsValid(value); + } + static constexpr Label Label_MIN = + FieldDescriptorProto_Label_Label_MIN; + static constexpr Label Label_MAX = + FieldDescriptorProto_Label_Label_MAX; + static constexpr int Label_ARRAYSIZE = + FieldDescriptorProto_Label_Label_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + Label_descriptor() { + return FieldDescriptorProto_Label_descriptor(); + } + template + static inline const std::string& Label_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Label_Name."); + return FieldDescriptorProto_Label_Name(enum_t_value); + } + static inline bool Label_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + Label* value) { + return FieldDescriptorProto_Label_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kExtendeeFieldNumber = 2, + kTypeNameFieldNumber = 6, + kDefaultValueFieldNumber = 7, + kJsonNameFieldNumber = 10, + kOptionsFieldNumber = 8, + kNumberFieldNumber = 3, + kOneofIndexFieldNumber = 9, + kProto3OptionalFieldNumber = 17, + kLabelFieldNumber = 4, + kTypeFieldNumber = 5, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string extendee = 2; + bool has_extendee() const; + private: + bool _internal_has_extendee() const; + public: + void clear_extendee(); + const std::string& extendee() const; + void set_extendee(const std::string& value); + void set_extendee(std::string&& value); + void set_extendee(const char* value); + void set_extendee(const char* value, size_t size); + std::string* mutable_extendee(); + std::string* release_extendee(); + void set_allocated_extendee(std::string* extendee); + private: + const std::string& _internal_extendee() const; + void _internal_set_extendee(const std::string& value); + std::string* _internal_mutable_extendee(); + public: + + // optional string type_name = 6; + bool has_type_name() const; + private: + bool _internal_has_type_name() const; + public: + void clear_type_name(); + const std::string& type_name() const; + void set_type_name(const std::string& value); + void set_type_name(std::string&& value); + void set_type_name(const char* value); + void set_type_name(const char* value, size_t size); + std::string* mutable_type_name(); + std::string* release_type_name(); + void set_allocated_type_name(std::string* type_name); + private: + const std::string& _internal_type_name() const; + void _internal_set_type_name(const std::string& value); + std::string* _internal_mutable_type_name(); + public: + + // optional string default_value = 7; + bool has_default_value() const; + private: + bool _internal_has_default_value() const; + public: + void clear_default_value(); + const std::string& default_value() const; + void set_default_value(const std::string& value); + void set_default_value(std::string&& value); + void set_default_value(const char* value); + void set_default_value(const char* value, size_t size); + std::string* mutable_default_value(); + std::string* release_default_value(); + void set_allocated_default_value(std::string* default_value); + private: + const std::string& _internal_default_value() const; + void _internal_set_default_value(const std::string& value); + std::string* _internal_mutable_default_value(); + public: + + // optional string json_name = 10; + bool has_json_name() const; + private: + bool _internal_has_json_name() const; + public: + void clear_json_name(); + const std::string& json_name() const; + void set_json_name(const std::string& value); + void set_json_name(std::string&& value); + void set_json_name(const char* value); + void set_json_name(const char* value, size_t size); + std::string* mutable_json_name(); + std::string* release_json_name(); + void set_allocated_json_name(std::string* json_name); + private: + const std::string& _internal_json_name() const; + void _internal_set_json_name(const std::string& value); + std::string* _internal_mutable_json_name(); + public: + + // optional .google.protobuf.FieldOptions options = 8; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::FieldOptions& options() const; + PROTOBUF_NAMESPACE_ID::FieldOptions* release_options(); + PROTOBUF_NAMESPACE_ID::FieldOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::FieldOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::FieldOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::FieldOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::FieldOptions* options); + PROTOBUF_NAMESPACE_ID::FieldOptions* unsafe_arena_release_options(); + + // optional int32 number = 3; + bool has_number() const; + private: + bool _internal_has_number() const; + public: + void clear_number(); + ::PROTOBUF_NAMESPACE_ID::int32 number() const; + void set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_number() const; + void _internal_set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 oneof_index = 9; + bool has_oneof_index() const; + private: + bool _internal_has_oneof_index() const; + public: + void clear_oneof_index(); + ::PROTOBUF_NAMESPACE_ID::int32 oneof_index() const; + void set_oneof_index(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_oneof_index() const; + void _internal_set_oneof_index(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional bool proto3_optional = 17; + bool has_proto3_optional() const; + private: + bool _internal_has_proto3_optional() const; + public: + void clear_proto3_optional(); + bool proto3_optional() const; + void set_proto3_optional(bool value); + private: + bool _internal_proto3_optional() const; + void _internal_set_proto3_optional(bool value); + public: + + // optional .google.protobuf.FieldDescriptorProto.Label label = 4; + bool has_label() const; + private: + bool _internal_has_label() const; + public: + void clear_label(); + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label label() const; + void set_label(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label value); + private: + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label _internal_label() const; + void _internal_set_label(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label value); + public: + + // optional .google.protobuf.FieldDescriptorProto.Type type = 5; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type type() const; + void set_type(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type value); + private: + PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type _internal_type() const; + void _internal_set_type(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.FieldDescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr extendee_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr type_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr default_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr json_name_; + PROTOBUF_NAMESPACE_ID::FieldOptions* options_; + ::PROTOBUF_NAMESPACE_ID::int32 number_; + ::PROTOBUF_NAMESPACE_ID::int32 oneof_index_; + bool proto3_optional_; + int label_; + int type_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT OneofDescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.OneofDescriptorProto) */ { + public: + inline OneofDescriptorProto() : OneofDescriptorProto(nullptr) {} + virtual ~OneofDescriptorProto(); + + OneofDescriptorProto(const OneofDescriptorProto& from); + OneofDescriptorProto(OneofDescriptorProto&& from) noexcept + : OneofDescriptorProto() { + *this = ::std::move(from); + } + + inline OneofDescriptorProto& operator=(const OneofDescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline OneofDescriptorProto& operator=(OneofDescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const OneofDescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const OneofDescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_OneofDescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 7; + + friend void swap(OneofDescriptorProto& a, OneofDescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(OneofDescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(OneofDescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline OneofDescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + OneofDescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const OneofDescriptorProto& from); + void MergeFrom(const OneofDescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(OneofDescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.OneofDescriptorProto"; + } + protected: + explicit OneofDescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kOptionsFieldNumber = 2, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .google.protobuf.OneofOptions options = 2; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::OneofOptions& options() const; + PROTOBUF_NAMESPACE_ID::OneofOptions* release_options(); + PROTOBUF_NAMESPACE_ID::OneofOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::OneofOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::OneofOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::OneofOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::OneofOptions* options); + PROTOBUF_NAMESPACE_ID::OneofOptions* unsafe_arena_release_options(); + + // @@protoc_insertion_point(class_scope:google.protobuf.OneofDescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + PROTOBUF_NAMESPACE_ID::OneofOptions* options_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT EnumDescriptorProto_EnumReservedRange PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.EnumDescriptorProto.EnumReservedRange) */ { + public: + inline EnumDescriptorProto_EnumReservedRange() : EnumDescriptorProto_EnumReservedRange(nullptr) {} + virtual ~EnumDescriptorProto_EnumReservedRange(); + + EnumDescriptorProto_EnumReservedRange(const EnumDescriptorProto_EnumReservedRange& from); + EnumDescriptorProto_EnumReservedRange(EnumDescriptorProto_EnumReservedRange&& from) noexcept + : EnumDescriptorProto_EnumReservedRange() { + *this = ::std::move(from); + } + + inline EnumDescriptorProto_EnumReservedRange& operator=(const EnumDescriptorProto_EnumReservedRange& from) { + CopyFrom(from); + return *this; + } + inline EnumDescriptorProto_EnumReservedRange& operator=(EnumDescriptorProto_EnumReservedRange&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const EnumDescriptorProto_EnumReservedRange& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const EnumDescriptorProto_EnumReservedRange* internal_default_instance() { + return reinterpret_cast( + &_EnumDescriptorProto_EnumReservedRange_default_instance_); + } + static constexpr int kIndexInFileMessages = + 8; + + friend void swap(EnumDescriptorProto_EnumReservedRange& a, EnumDescriptorProto_EnumReservedRange& b) { + a.Swap(&b); + } + inline void Swap(EnumDescriptorProto_EnumReservedRange* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(EnumDescriptorProto_EnumReservedRange* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline EnumDescriptorProto_EnumReservedRange* New() const final { + return CreateMaybeMessage(nullptr); + } + + EnumDescriptorProto_EnumReservedRange* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const EnumDescriptorProto_EnumReservedRange& from); + void MergeFrom(const EnumDescriptorProto_EnumReservedRange& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(EnumDescriptorProto_EnumReservedRange* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.EnumDescriptorProto.EnumReservedRange"; + } + protected: + explicit EnumDescriptorProto_EnumReservedRange(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kStartFieldNumber = 1, + kEndFieldNumber = 2, + }; + // optional int32 start = 1; + bool has_start() const; + private: + bool _internal_has_start() const; + public: + void clear_start(); + ::PROTOBUF_NAMESPACE_ID::int32 start() const; + void set_start(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_start() const; + void _internal_set_start(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 end = 2; + bool has_end() const; + private: + bool _internal_has_end() const; + public: + void clear_end(); + ::PROTOBUF_NAMESPACE_ID::int32 end() const; + void set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_end() const; + void _internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.EnumDescriptorProto.EnumReservedRange) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::int32 start_; + ::PROTOBUF_NAMESPACE_ID::int32 end_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT EnumDescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.EnumDescriptorProto) */ { + public: + inline EnumDescriptorProto() : EnumDescriptorProto(nullptr) {} + virtual ~EnumDescriptorProto(); + + EnumDescriptorProto(const EnumDescriptorProto& from); + EnumDescriptorProto(EnumDescriptorProto&& from) noexcept + : EnumDescriptorProto() { + *this = ::std::move(from); + } + + inline EnumDescriptorProto& operator=(const EnumDescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline EnumDescriptorProto& operator=(EnumDescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const EnumDescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const EnumDescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_EnumDescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 9; + + friend void swap(EnumDescriptorProto& a, EnumDescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(EnumDescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(EnumDescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline EnumDescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + EnumDescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const EnumDescriptorProto& from); + void MergeFrom(const EnumDescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(EnumDescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.EnumDescriptorProto"; + } + protected: + explicit EnumDescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef EnumDescriptorProto_EnumReservedRange EnumReservedRange; + + // accessors ------------------------------------------------------- + + enum : int { + kValueFieldNumber = 2, + kReservedRangeFieldNumber = 4, + kReservedNameFieldNumber = 5, + kNameFieldNumber = 1, + kOptionsFieldNumber = 3, + }; + // repeated .google.protobuf.EnumValueDescriptorProto value = 2; + int value_size() const; + private: + int _internal_value_size() const; + public: + void clear_value(); + PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto* mutable_value(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto >* + mutable_value(); + private: + const PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto& _internal_value(int index) const; + PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto* _internal_add_value(); + public: + const PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto& value(int index) const; + PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto* add_value(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto >& + value() const; + + // repeated .google.protobuf.EnumDescriptorProto.EnumReservedRange reserved_range = 4; + int reserved_range_size() const; + private: + int _internal_reserved_range_size() const; + public: + void clear_reserved_range(); + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange* mutable_reserved_range(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange >* + mutable_reserved_range(); + private: + const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange& _internal_reserved_range(int index) const; + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange* _internal_add_reserved_range(); + public: + const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange& reserved_range(int index) const; + PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange* add_reserved_range(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange >& + reserved_range() const; + + // repeated string reserved_name = 5; + int reserved_name_size() const; + private: + int _internal_reserved_name_size() const; + public: + void clear_reserved_name(); + const std::string& reserved_name(int index) const; + std::string* mutable_reserved_name(int index); + void set_reserved_name(int index, const std::string& value); + void set_reserved_name(int index, std::string&& value); + void set_reserved_name(int index, const char* value); + void set_reserved_name(int index, const char* value, size_t size); + std::string* add_reserved_name(); + void add_reserved_name(const std::string& value); + void add_reserved_name(std::string&& value); + void add_reserved_name(const char* value); + void add_reserved_name(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& reserved_name() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_reserved_name(); + private: + const std::string& _internal_reserved_name(int index) const; + std::string* _internal_add_reserved_name(); + public: + + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .google.protobuf.EnumOptions options = 3; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::EnumOptions& options() const; + PROTOBUF_NAMESPACE_ID::EnumOptions* release_options(); + PROTOBUF_NAMESPACE_ID::EnumOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::EnumOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::EnumOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::EnumOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::EnumOptions* options); + PROTOBUF_NAMESPACE_ID::EnumOptions* unsafe_arena_release_options(); + + // @@protoc_insertion_point(class_scope:google.protobuf.EnumDescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto > value_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange > reserved_range_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField reserved_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + PROTOBUF_NAMESPACE_ID::EnumOptions* options_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT EnumValueDescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.EnumValueDescriptorProto) */ { + public: + inline EnumValueDescriptorProto() : EnumValueDescriptorProto(nullptr) {} + virtual ~EnumValueDescriptorProto(); + + EnumValueDescriptorProto(const EnumValueDescriptorProto& from); + EnumValueDescriptorProto(EnumValueDescriptorProto&& from) noexcept + : EnumValueDescriptorProto() { + *this = ::std::move(from); + } + + inline EnumValueDescriptorProto& operator=(const EnumValueDescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline EnumValueDescriptorProto& operator=(EnumValueDescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const EnumValueDescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const EnumValueDescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_EnumValueDescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 10; + + friend void swap(EnumValueDescriptorProto& a, EnumValueDescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(EnumValueDescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(EnumValueDescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline EnumValueDescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + EnumValueDescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const EnumValueDescriptorProto& from); + void MergeFrom(const EnumValueDescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(EnumValueDescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.EnumValueDescriptorProto"; + } + protected: + explicit EnumValueDescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kOptionsFieldNumber = 3, + kNumberFieldNumber = 2, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .google.protobuf.EnumValueOptions options = 3; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::EnumValueOptions& options() const; + PROTOBUF_NAMESPACE_ID::EnumValueOptions* release_options(); + PROTOBUF_NAMESPACE_ID::EnumValueOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::EnumValueOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::EnumValueOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::EnumValueOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::EnumValueOptions* options); + PROTOBUF_NAMESPACE_ID::EnumValueOptions* unsafe_arena_release_options(); + + // optional int32 number = 2; + bool has_number() const; + private: + bool _internal_has_number() const; + public: + void clear_number(); + ::PROTOBUF_NAMESPACE_ID::int32 number() const; + void set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_number() const; + void _internal_set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.EnumValueDescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + PROTOBUF_NAMESPACE_ID::EnumValueOptions* options_; + ::PROTOBUF_NAMESPACE_ID::int32 number_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT ServiceDescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.ServiceDescriptorProto) */ { + public: + inline ServiceDescriptorProto() : ServiceDescriptorProto(nullptr) {} + virtual ~ServiceDescriptorProto(); + + ServiceDescriptorProto(const ServiceDescriptorProto& from); + ServiceDescriptorProto(ServiceDescriptorProto&& from) noexcept + : ServiceDescriptorProto() { + *this = ::std::move(from); + } + + inline ServiceDescriptorProto& operator=(const ServiceDescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline ServiceDescriptorProto& operator=(ServiceDescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ServiceDescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ServiceDescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_ServiceDescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 11; + + friend void swap(ServiceDescriptorProto& a, ServiceDescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(ServiceDescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ServiceDescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ServiceDescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + ServiceDescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ServiceDescriptorProto& from); + void MergeFrom(const ServiceDescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ServiceDescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.ServiceDescriptorProto"; + } + protected: + explicit ServiceDescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kMethodFieldNumber = 2, + kNameFieldNumber = 1, + kOptionsFieldNumber = 3, + }; + // repeated .google.protobuf.MethodDescriptorProto method = 2; + int method_size() const; + private: + int _internal_method_size() const; + public: + void clear_method(); + PROTOBUF_NAMESPACE_ID::MethodDescriptorProto* mutable_method(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::MethodDescriptorProto >* + mutable_method(); + private: + const PROTOBUF_NAMESPACE_ID::MethodDescriptorProto& _internal_method(int index) const; + PROTOBUF_NAMESPACE_ID::MethodDescriptorProto* _internal_add_method(); + public: + const PROTOBUF_NAMESPACE_ID::MethodDescriptorProto& method(int index) const; + PROTOBUF_NAMESPACE_ID::MethodDescriptorProto* add_method(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::MethodDescriptorProto >& + method() const; + + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .google.protobuf.ServiceOptions options = 3; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::ServiceOptions& options() const; + PROTOBUF_NAMESPACE_ID::ServiceOptions* release_options(); + PROTOBUF_NAMESPACE_ID::ServiceOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::ServiceOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::ServiceOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::ServiceOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::ServiceOptions* options); + PROTOBUF_NAMESPACE_ID::ServiceOptions* unsafe_arena_release_options(); + + // @@protoc_insertion_point(class_scope:google.protobuf.ServiceDescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::MethodDescriptorProto > method_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + PROTOBUF_NAMESPACE_ID::ServiceOptions* options_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT MethodDescriptorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.MethodDescriptorProto) */ { + public: + inline MethodDescriptorProto() : MethodDescriptorProto(nullptr) {} + virtual ~MethodDescriptorProto(); + + MethodDescriptorProto(const MethodDescriptorProto& from); + MethodDescriptorProto(MethodDescriptorProto&& from) noexcept + : MethodDescriptorProto() { + *this = ::std::move(from); + } + + inline MethodDescriptorProto& operator=(const MethodDescriptorProto& from) { + CopyFrom(from); + return *this; + } + inline MethodDescriptorProto& operator=(MethodDescriptorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const MethodDescriptorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const MethodDescriptorProto* internal_default_instance() { + return reinterpret_cast( + &_MethodDescriptorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 12; + + friend void swap(MethodDescriptorProto& a, MethodDescriptorProto& b) { + a.Swap(&b); + } + inline void Swap(MethodDescriptorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(MethodDescriptorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline MethodDescriptorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + MethodDescriptorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const MethodDescriptorProto& from); + void MergeFrom(const MethodDescriptorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(MethodDescriptorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.MethodDescriptorProto"; + } + protected: + explicit MethodDescriptorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kInputTypeFieldNumber = 2, + kOutputTypeFieldNumber = 3, + kOptionsFieldNumber = 4, + kClientStreamingFieldNumber = 5, + kServerStreamingFieldNumber = 6, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string input_type = 2; + bool has_input_type() const; + private: + bool _internal_has_input_type() const; + public: + void clear_input_type(); + const std::string& input_type() const; + void set_input_type(const std::string& value); + void set_input_type(std::string&& value); + void set_input_type(const char* value); + void set_input_type(const char* value, size_t size); + std::string* mutable_input_type(); + std::string* release_input_type(); + void set_allocated_input_type(std::string* input_type); + private: + const std::string& _internal_input_type() const; + void _internal_set_input_type(const std::string& value); + std::string* _internal_mutable_input_type(); + public: + + // optional string output_type = 3; + bool has_output_type() const; + private: + bool _internal_has_output_type() const; + public: + void clear_output_type(); + const std::string& output_type() const; + void set_output_type(const std::string& value); + void set_output_type(std::string&& value); + void set_output_type(const char* value); + void set_output_type(const char* value, size_t size); + std::string* mutable_output_type(); + std::string* release_output_type(); + void set_allocated_output_type(std::string* output_type); + private: + const std::string& _internal_output_type() const; + void _internal_set_output_type(const std::string& value); + std::string* _internal_mutable_output_type(); + public: + + // optional .google.protobuf.MethodOptions options = 4; + bool has_options() const; + private: + bool _internal_has_options() const; + public: + void clear_options(); + const PROTOBUF_NAMESPACE_ID::MethodOptions& options() const; + PROTOBUF_NAMESPACE_ID::MethodOptions* release_options(); + PROTOBUF_NAMESPACE_ID::MethodOptions* mutable_options(); + void set_allocated_options(PROTOBUF_NAMESPACE_ID::MethodOptions* options); + private: + const PROTOBUF_NAMESPACE_ID::MethodOptions& _internal_options() const; + PROTOBUF_NAMESPACE_ID::MethodOptions* _internal_mutable_options(); + public: + void unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::MethodOptions* options); + PROTOBUF_NAMESPACE_ID::MethodOptions* unsafe_arena_release_options(); + + // optional bool client_streaming = 5 [default = false]; + bool has_client_streaming() const; + private: + bool _internal_has_client_streaming() const; + public: + void clear_client_streaming(); + bool client_streaming() const; + void set_client_streaming(bool value); + private: + bool _internal_client_streaming() const; + void _internal_set_client_streaming(bool value); + public: + + // optional bool server_streaming = 6 [default = false]; + bool has_server_streaming() const; + private: + bool _internal_has_server_streaming() const; + public: + void clear_server_streaming(); + bool server_streaming() const; + void set_server_streaming(bool value); + private: + bool _internal_server_streaming() const; + void _internal_set_server_streaming(bool value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.MethodDescriptorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr input_type_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr output_type_; + PROTOBUF_NAMESPACE_ID::MethodOptions* options_; + bool client_streaming_; + bool server_streaming_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT FileOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.FileOptions) */ { + public: + inline FileOptions() : FileOptions(nullptr) {} + virtual ~FileOptions(); + + FileOptions(const FileOptions& from); + FileOptions(FileOptions&& from) noexcept + : FileOptions() { + *this = ::std::move(from); + } + + inline FileOptions& operator=(const FileOptions& from) { + CopyFrom(from); + return *this; + } + inline FileOptions& operator=(FileOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const FileOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const FileOptions* internal_default_instance() { + return reinterpret_cast( + &_FileOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 13; + + friend void swap(FileOptions& a, FileOptions& b) { + a.Swap(&b); + } + inline void Swap(FileOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(FileOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline FileOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + FileOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const FileOptions& from); + void MergeFrom(const FileOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(FileOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.FileOptions"; + } + protected: + explicit FileOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef FileOptions_OptimizeMode OptimizeMode; + static constexpr OptimizeMode SPEED = + FileOptions_OptimizeMode_SPEED; + static constexpr OptimizeMode CODE_SIZE = + FileOptions_OptimizeMode_CODE_SIZE; + static constexpr OptimizeMode LITE_RUNTIME = + FileOptions_OptimizeMode_LITE_RUNTIME; + static inline bool OptimizeMode_IsValid(int value) { + return FileOptions_OptimizeMode_IsValid(value); + } + static constexpr OptimizeMode OptimizeMode_MIN = + FileOptions_OptimizeMode_OptimizeMode_MIN; + static constexpr OptimizeMode OptimizeMode_MAX = + FileOptions_OptimizeMode_OptimizeMode_MAX; + static constexpr int OptimizeMode_ARRAYSIZE = + FileOptions_OptimizeMode_OptimizeMode_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + OptimizeMode_descriptor() { + return FileOptions_OptimizeMode_descriptor(); + } + template + static inline const std::string& OptimizeMode_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function OptimizeMode_Name."); + return FileOptions_OptimizeMode_Name(enum_t_value); + } + static inline bool OptimizeMode_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + OptimizeMode* value) { + return FileOptions_OptimizeMode_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + kJavaPackageFieldNumber = 1, + kJavaOuterClassnameFieldNumber = 8, + kGoPackageFieldNumber = 11, + kObjcClassPrefixFieldNumber = 36, + kCsharpNamespaceFieldNumber = 37, + kSwiftPrefixFieldNumber = 39, + kPhpClassPrefixFieldNumber = 40, + kPhpNamespaceFieldNumber = 41, + kPhpMetadataNamespaceFieldNumber = 44, + kRubyPackageFieldNumber = 45, + kJavaMultipleFilesFieldNumber = 10, + kJavaGenerateEqualsAndHashFieldNumber = 20, + kJavaStringCheckUtf8FieldNumber = 27, + kCcGenericServicesFieldNumber = 16, + kJavaGenericServicesFieldNumber = 17, + kPyGenericServicesFieldNumber = 18, + kPhpGenericServicesFieldNumber = 42, + kDeprecatedFieldNumber = 23, + kOptimizeForFieldNumber = 9, + kCcEnableArenasFieldNumber = 31, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + // optional string java_package = 1; + bool has_java_package() const; + private: + bool _internal_has_java_package() const; + public: + void clear_java_package(); + const std::string& java_package() const; + void set_java_package(const std::string& value); + void set_java_package(std::string&& value); + void set_java_package(const char* value); + void set_java_package(const char* value, size_t size); + std::string* mutable_java_package(); + std::string* release_java_package(); + void set_allocated_java_package(std::string* java_package); + private: + const std::string& _internal_java_package() const; + void _internal_set_java_package(const std::string& value); + std::string* _internal_mutable_java_package(); + public: + + // optional string java_outer_classname = 8; + bool has_java_outer_classname() const; + private: + bool _internal_has_java_outer_classname() const; + public: + void clear_java_outer_classname(); + const std::string& java_outer_classname() const; + void set_java_outer_classname(const std::string& value); + void set_java_outer_classname(std::string&& value); + void set_java_outer_classname(const char* value); + void set_java_outer_classname(const char* value, size_t size); + std::string* mutable_java_outer_classname(); + std::string* release_java_outer_classname(); + void set_allocated_java_outer_classname(std::string* java_outer_classname); + private: + const std::string& _internal_java_outer_classname() const; + void _internal_set_java_outer_classname(const std::string& value); + std::string* _internal_mutable_java_outer_classname(); + public: + + // optional string go_package = 11; + bool has_go_package() const; + private: + bool _internal_has_go_package() const; + public: + void clear_go_package(); + const std::string& go_package() const; + void set_go_package(const std::string& value); + void set_go_package(std::string&& value); + void set_go_package(const char* value); + void set_go_package(const char* value, size_t size); + std::string* mutable_go_package(); + std::string* release_go_package(); + void set_allocated_go_package(std::string* go_package); + private: + const std::string& _internal_go_package() const; + void _internal_set_go_package(const std::string& value); + std::string* _internal_mutable_go_package(); + public: + + // optional string objc_class_prefix = 36; + bool has_objc_class_prefix() const; + private: + bool _internal_has_objc_class_prefix() const; + public: + void clear_objc_class_prefix(); + const std::string& objc_class_prefix() const; + void set_objc_class_prefix(const std::string& value); + void set_objc_class_prefix(std::string&& value); + void set_objc_class_prefix(const char* value); + void set_objc_class_prefix(const char* value, size_t size); + std::string* mutable_objc_class_prefix(); + std::string* release_objc_class_prefix(); + void set_allocated_objc_class_prefix(std::string* objc_class_prefix); + private: + const std::string& _internal_objc_class_prefix() const; + void _internal_set_objc_class_prefix(const std::string& value); + std::string* _internal_mutable_objc_class_prefix(); + public: + + // optional string csharp_namespace = 37; + bool has_csharp_namespace() const; + private: + bool _internal_has_csharp_namespace() const; + public: + void clear_csharp_namespace(); + const std::string& csharp_namespace() const; + void set_csharp_namespace(const std::string& value); + void set_csharp_namespace(std::string&& value); + void set_csharp_namespace(const char* value); + void set_csharp_namespace(const char* value, size_t size); + std::string* mutable_csharp_namespace(); + std::string* release_csharp_namespace(); + void set_allocated_csharp_namespace(std::string* csharp_namespace); + private: + const std::string& _internal_csharp_namespace() const; + void _internal_set_csharp_namespace(const std::string& value); + std::string* _internal_mutable_csharp_namespace(); + public: + + // optional string swift_prefix = 39; + bool has_swift_prefix() const; + private: + bool _internal_has_swift_prefix() const; + public: + void clear_swift_prefix(); + const std::string& swift_prefix() const; + void set_swift_prefix(const std::string& value); + void set_swift_prefix(std::string&& value); + void set_swift_prefix(const char* value); + void set_swift_prefix(const char* value, size_t size); + std::string* mutable_swift_prefix(); + std::string* release_swift_prefix(); + void set_allocated_swift_prefix(std::string* swift_prefix); + private: + const std::string& _internal_swift_prefix() const; + void _internal_set_swift_prefix(const std::string& value); + std::string* _internal_mutable_swift_prefix(); + public: + + // optional string php_class_prefix = 40; + bool has_php_class_prefix() const; + private: + bool _internal_has_php_class_prefix() const; + public: + void clear_php_class_prefix(); + const std::string& php_class_prefix() const; + void set_php_class_prefix(const std::string& value); + void set_php_class_prefix(std::string&& value); + void set_php_class_prefix(const char* value); + void set_php_class_prefix(const char* value, size_t size); + std::string* mutable_php_class_prefix(); + std::string* release_php_class_prefix(); + void set_allocated_php_class_prefix(std::string* php_class_prefix); + private: + const std::string& _internal_php_class_prefix() const; + void _internal_set_php_class_prefix(const std::string& value); + std::string* _internal_mutable_php_class_prefix(); + public: + + // optional string php_namespace = 41; + bool has_php_namespace() const; + private: + bool _internal_has_php_namespace() const; + public: + void clear_php_namespace(); + const std::string& php_namespace() const; + void set_php_namespace(const std::string& value); + void set_php_namespace(std::string&& value); + void set_php_namespace(const char* value); + void set_php_namespace(const char* value, size_t size); + std::string* mutable_php_namespace(); + std::string* release_php_namespace(); + void set_allocated_php_namespace(std::string* php_namespace); + private: + const std::string& _internal_php_namespace() const; + void _internal_set_php_namespace(const std::string& value); + std::string* _internal_mutable_php_namespace(); + public: + + // optional string php_metadata_namespace = 44; + bool has_php_metadata_namespace() const; + private: + bool _internal_has_php_metadata_namespace() const; + public: + void clear_php_metadata_namespace(); + const std::string& php_metadata_namespace() const; + void set_php_metadata_namespace(const std::string& value); + void set_php_metadata_namespace(std::string&& value); + void set_php_metadata_namespace(const char* value); + void set_php_metadata_namespace(const char* value, size_t size); + std::string* mutable_php_metadata_namespace(); + std::string* release_php_metadata_namespace(); + void set_allocated_php_metadata_namespace(std::string* php_metadata_namespace); + private: + const std::string& _internal_php_metadata_namespace() const; + void _internal_set_php_metadata_namespace(const std::string& value); + std::string* _internal_mutable_php_metadata_namespace(); + public: + + // optional string ruby_package = 45; + bool has_ruby_package() const; + private: + bool _internal_has_ruby_package() const; + public: + void clear_ruby_package(); + const std::string& ruby_package() const; + void set_ruby_package(const std::string& value); + void set_ruby_package(std::string&& value); + void set_ruby_package(const char* value); + void set_ruby_package(const char* value, size_t size); + std::string* mutable_ruby_package(); + std::string* release_ruby_package(); + void set_allocated_ruby_package(std::string* ruby_package); + private: + const std::string& _internal_ruby_package() const; + void _internal_set_ruby_package(const std::string& value); + std::string* _internal_mutable_ruby_package(); + public: + + // optional bool java_multiple_files = 10 [default = false]; + bool has_java_multiple_files() const; + private: + bool _internal_has_java_multiple_files() const; + public: + void clear_java_multiple_files(); + bool java_multiple_files() const; + void set_java_multiple_files(bool value); + private: + bool _internal_java_multiple_files() const; + void _internal_set_java_multiple_files(bool value); + public: + + // optional bool java_generate_equals_and_hash = 20 [deprecated = true]; + PROTOBUF_DEPRECATED bool has_java_generate_equals_and_hash() const; + private: + bool _internal_has_java_generate_equals_and_hash() const; + public: + PROTOBUF_DEPRECATED void clear_java_generate_equals_and_hash(); + PROTOBUF_DEPRECATED bool java_generate_equals_and_hash() const; + PROTOBUF_DEPRECATED void set_java_generate_equals_and_hash(bool value); + private: + bool _internal_java_generate_equals_and_hash() const; + void _internal_set_java_generate_equals_and_hash(bool value); + public: + + // optional bool java_string_check_utf8 = 27 [default = false]; + bool has_java_string_check_utf8() const; + private: + bool _internal_has_java_string_check_utf8() const; + public: + void clear_java_string_check_utf8(); + bool java_string_check_utf8() const; + void set_java_string_check_utf8(bool value); + private: + bool _internal_java_string_check_utf8() const; + void _internal_set_java_string_check_utf8(bool value); + public: + + // optional bool cc_generic_services = 16 [default = false]; + bool has_cc_generic_services() const; + private: + bool _internal_has_cc_generic_services() const; + public: + void clear_cc_generic_services(); + bool cc_generic_services() const; + void set_cc_generic_services(bool value); + private: + bool _internal_cc_generic_services() const; + void _internal_set_cc_generic_services(bool value); + public: + + // optional bool java_generic_services = 17 [default = false]; + bool has_java_generic_services() const; + private: + bool _internal_has_java_generic_services() const; + public: + void clear_java_generic_services(); + bool java_generic_services() const; + void set_java_generic_services(bool value); + private: + bool _internal_java_generic_services() const; + void _internal_set_java_generic_services(bool value); + public: + + // optional bool py_generic_services = 18 [default = false]; + bool has_py_generic_services() const; + private: + bool _internal_has_py_generic_services() const; + public: + void clear_py_generic_services(); + bool py_generic_services() const; + void set_py_generic_services(bool value); + private: + bool _internal_py_generic_services() const; + void _internal_set_py_generic_services(bool value); + public: + + // optional bool php_generic_services = 42 [default = false]; + bool has_php_generic_services() const; + private: + bool _internal_has_php_generic_services() const; + public: + void clear_php_generic_services(); + bool php_generic_services() const; + void set_php_generic_services(bool value); + private: + bool _internal_php_generic_services() const; + void _internal_set_php_generic_services(bool value); + public: + + // optional bool deprecated = 23 [default = false]; + bool has_deprecated() const; + private: + bool _internal_has_deprecated() const; + public: + void clear_deprecated(); + bool deprecated() const; + void set_deprecated(bool value); + private: + bool _internal_deprecated() const; + void _internal_set_deprecated(bool value); + public: + + // optional .google.protobuf.FileOptions.OptimizeMode optimize_for = 9 [default = SPEED]; + bool has_optimize_for() const; + private: + bool _internal_has_optimize_for() const; + public: + void clear_optimize_for(); + PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode optimize_for() const; + void set_optimize_for(PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode value); + private: + PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode _internal_optimize_for() const; + void _internal_set_optimize_for(PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode value); + public: + + // optional bool cc_enable_arenas = 31 [default = true]; + bool has_cc_enable_arenas() const; + private: + bool _internal_has_cc_enable_arenas() const; + public: + void clear_cc_enable_arenas(); + bool cc_enable_arenas() const; + void set_cc_enable_arenas(bool value); + private: + bool _internal_cc_enable_arenas() const; + void _internal_set_cc_enable_arenas(bool value); + public: + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(FileOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.FileOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr java_package_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr java_outer_classname_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr go_package_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr objc_class_prefix_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr csharp_namespace_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr swift_prefix_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr php_class_prefix_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr php_namespace_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr php_metadata_namespace_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr ruby_package_; + bool java_multiple_files_; + bool java_generate_equals_and_hash_; + bool java_string_check_utf8_; + bool cc_generic_services_; + bool java_generic_services_; + bool py_generic_services_; + bool php_generic_services_; + bool deprecated_; + int optimize_for_; + bool cc_enable_arenas_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT MessageOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.MessageOptions) */ { + public: + inline MessageOptions() : MessageOptions(nullptr) {} + virtual ~MessageOptions(); + + MessageOptions(const MessageOptions& from); + MessageOptions(MessageOptions&& from) noexcept + : MessageOptions() { + *this = ::std::move(from); + } + + inline MessageOptions& operator=(const MessageOptions& from) { + CopyFrom(from); + return *this; + } + inline MessageOptions& operator=(MessageOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const MessageOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const MessageOptions* internal_default_instance() { + return reinterpret_cast( + &_MessageOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 14; + + friend void swap(MessageOptions& a, MessageOptions& b) { + a.Swap(&b); + } + inline void Swap(MessageOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(MessageOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline MessageOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + MessageOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const MessageOptions& from); + void MergeFrom(const MessageOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(MessageOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.MessageOptions"; + } + protected: + explicit MessageOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + kMessageSetWireFormatFieldNumber = 1, + kNoStandardDescriptorAccessorFieldNumber = 2, + kDeprecatedFieldNumber = 3, + kMapEntryFieldNumber = 7, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + // optional bool message_set_wire_format = 1 [default = false]; + bool has_message_set_wire_format() const; + private: + bool _internal_has_message_set_wire_format() const; + public: + void clear_message_set_wire_format(); + bool message_set_wire_format() const; + void set_message_set_wire_format(bool value); + private: + bool _internal_message_set_wire_format() const; + void _internal_set_message_set_wire_format(bool value); + public: + + // optional bool no_standard_descriptor_accessor = 2 [default = false]; + bool has_no_standard_descriptor_accessor() const; + private: + bool _internal_has_no_standard_descriptor_accessor() const; + public: + void clear_no_standard_descriptor_accessor(); + bool no_standard_descriptor_accessor() const; + void set_no_standard_descriptor_accessor(bool value); + private: + bool _internal_no_standard_descriptor_accessor() const; + void _internal_set_no_standard_descriptor_accessor(bool value); + public: + + // optional bool deprecated = 3 [default = false]; + bool has_deprecated() const; + private: + bool _internal_has_deprecated() const; + public: + void clear_deprecated(); + bool deprecated() const; + void set_deprecated(bool value); + private: + bool _internal_deprecated() const; + void _internal_set_deprecated(bool value); + public: + + // optional bool map_entry = 7; + bool has_map_entry() const; + private: + bool _internal_has_map_entry() const; + public: + void clear_map_entry(); + bool map_entry() const; + void set_map_entry(bool value); + private: + bool _internal_map_entry() const; + void _internal_set_map_entry(bool value); + public: + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(MessageOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.MessageOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + bool message_set_wire_format_; + bool no_standard_descriptor_accessor_; + bool deprecated_; + bool map_entry_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT FieldOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.FieldOptions) */ { + public: + inline FieldOptions() : FieldOptions(nullptr) {} + virtual ~FieldOptions(); + + FieldOptions(const FieldOptions& from); + FieldOptions(FieldOptions&& from) noexcept + : FieldOptions() { + *this = ::std::move(from); + } + + inline FieldOptions& operator=(const FieldOptions& from) { + CopyFrom(from); + return *this; + } + inline FieldOptions& operator=(FieldOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const FieldOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const FieldOptions* internal_default_instance() { + return reinterpret_cast( + &_FieldOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 15; + + friend void swap(FieldOptions& a, FieldOptions& b) { + a.Swap(&b); + } + inline void Swap(FieldOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(FieldOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline FieldOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + FieldOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const FieldOptions& from); + void MergeFrom(const FieldOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(FieldOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.FieldOptions"; + } + protected: + explicit FieldOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef FieldOptions_CType CType; + static constexpr CType STRING = + FieldOptions_CType_STRING; + static constexpr CType CORD = + FieldOptions_CType_CORD; + static constexpr CType STRING_PIECE = + FieldOptions_CType_STRING_PIECE; + static inline bool CType_IsValid(int value) { + return FieldOptions_CType_IsValid(value); + } + static constexpr CType CType_MIN = + FieldOptions_CType_CType_MIN; + static constexpr CType CType_MAX = + FieldOptions_CType_CType_MAX; + static constexpr int CType_ARRAYSIZE = + FieldOptions_CType_CType_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + CType_descriptor() { + return FieldOptions_CType_descriptor(); + } + template + static inline const std::string& CType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function CType_Name."); + return FieldOptions_CType_Name(enum_t_value); + } + static inline bool CType_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + CType* value) { + return FieldOptions_CType_Parse(name, value); + } + + typedef FieldOptions_JSType JSType; + static constexpr JSType JS_NORMAL = + FieldOptions_JSType_JS_NORMAL; + static constexpr JSType JS_STRING = + FieldOptions_JSType_JS_STRING; + static constexpr JSType JS_NUMBER = + FieldOptions_JSType_JS_NUMBER; + static inline bool JSType_IsValid(int value) { + return FieldOptions_JSType_IsValid(value); + } + static constexpr JSType JSType_MIN = + FieldOptions_JSType_JSType_MIN; + static constexpr JSType JSType_MAX = + FieldOptions_JSType_JSType_MAX; + static constexpr int JSType_ARRAYSIZE = + FieldOptions_JSType_JSType_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + JSType_descriptor() { + return FieldOptions_JSType_descriptor(); + } + template + static inline const std::string& JSType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function JSType_Name."); + return FieldOptions_JSType_Name(enum_t_value); + } + static inline bool JSType_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + JSType* value) { + return FieldOptions_JSType_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + kCtypeFieldNumber = 1, + kPackedFieldNumber = 2, + kLazyFieldNumber = 5, + kDeprecatedFieldNumber = 3, + kWeakFieldNumber = 10, + kJstypeFieldNumber = 6, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + // optional .google.protobuf.FieldOptions.CType ctype = 1 [default = STRING]; + bool has_ctype() const; + private: + bool _internal_has_ctype() const; + public: + void clear_ctype(); + PROTOBUF_NAMESPACE_ID::FieldOptions_CType ctype() const; + void set_ctype(PROTOBUF_NAMESPACE_ID::FieldOptions_CType value); + private: + PROTOBUF_NAMESPACE_ID::FieldOptions_CType _internal_ctype() const; + void _internal_set_ctype(PROTOBUF_NAMESPACE_ID::FieldOptions_CType value); + public: + + // optional bool packed = 2; + bool has_packed() const; + private: + bool _internal_has_packed() const; + public: + void clear_packed(); + bool packed() const; + void set_packed(bool value); + private: + bool _internal_packed() const; + void _internal_set_packed(bool value); + public: + + // optional bool lazy = 5 [default = false]; + bool has_lazy() const; + private: + bool _internal_has_lazy() const; + public: + void clear_lazy(); + bool lazy() const; + void set_lazy(bool value); + private: + bool _internal_lazy() const; + void _internal_set_lazy(bool value); + public: + + // optional bool deprecated = 3 [default = false]; + bool has_deprecated() const; + private: + bool _internal_has_deprecated() const; + public: + void clear_deprecated(); + bool deprecated() const; + void set_deprecated(bool value); + private: + bool _internal_deprecated() const; + void _internal_set_deprecated(bool value); + public: + + // optional bool weak = 10 [default = false]; + bool has_weak() const; + private: + bool _internal_has_weak() const; + public: + void clear_weak(); + bool weak() const; + void set_weak(bool value); + private: + bool _internal_weak() const; + void _internal_set_weak(bool value); + public: + + // optional .google.protobuf.FieldOptions.JSType jstype = 6 [default = JS_NORMAL]; + bool has_jstype() const; + private: + bool _internal_has_jstype() const; + public: + void clear_jstype(); + PROTOBUF_NAMESPACE_ID::FieldOptions_JSType jstype() const; + void set_jstype(PROTOBUF_NAMESPACE_ID::FieldOptions_JSType value); + private: + PROTOBUF_NAMESPACE_ID::FieldOptions_JSType _internal_jstype() const; + void _internal_set_jstype(PROTOBUF_NAMESPACE_ID::FieldOptions_JSType value); + public: + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(FieldOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.FieldOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + int ctype_; + bool packed_; + bool lazy_; + bool deprecated_; + bool weak_; + int jstype_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT OneofOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.OneofOptions) */ { + public: + inline OneofOptions() : OneofOptions(nullptr) {} + virtual ~OneofOptions(); + + OneofOptions(const OneofOptions& from); + OneofOptions(OneofOptions&& from) noexcept + : OneofOptions() { + *this = ::std::move(from); + } + + inline OneofOptions& operator=(const OneofOptions& from) { + CopyFrom(from); + return *this; + } + inline OneofOptions& operator=(OneofOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const OneofOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const OneofOptions* internal_default_instance() { + return reinterpret_cast( + &_OneofOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 16; + + friend void swap(OneofOptions& a, OneofOptions& b) { + a.Swap(&b); + } + inline void Swap(OneofOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(OneofOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline OneofOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + OneofOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const OneofOptions& from); + void MergeFrom(const OneofOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(OneofOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.OneofOptions"; + } + protected: + explicit OneofOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(OneofOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.OneofOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT EnumOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.EnumOptions) */ { + public: + inline EnumOptions() : EnumOptions(nullptr) {} + virtual ~EnumOptions(); + + EnumOptions(const EnumOptions& from); + EnumOptions(EnumOptions&& from) noexcept + : EnumOptions() { + *this = ::std::move(from); + } + + inline EnumOptions& operator=(const EnumOptions& from) { + CopyFrom(from); + return *this; + } + inline EnumOptions& operator=(EnumOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const EnumOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const EnumOptions* internal_default_instance() { + return reinterpret_cast( + &_EnumOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 17; + + friend void swap(EnumOptions& a, EnumOptions& b) { + a.Swap(&b); + } + inline void Swap(EnumOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(EnumOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline EnumOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + EnumOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const EnumOptions& from); + void MergeFrom(const EnumOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(EnumOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.EnumOptions"; + } + protected: + explicit EnumOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + kAllowAliasFieldNumber = 2, + kDeprecatedFieldNumber = 3, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + // optional bool allow_alias = 2; + bool has_allow_alias() const; + private: + bool _internal_has_allow_alias() const; + public: + void clear_allow_alias(); + bool allow_alias() const; + void set_allow_alias(bool value); + private: + bool _internal_allow_alias() const; + void _internal_set_allow_alias(bool value); + public: + + // optional bool deprecated = 3 [default = false]; + bool has_deprecated() const; + private: + bool _internal_has_deprecated() const; + public: + void clear_deprecated(); + bool deprecated() const; + void set_deprecated(bool value); + private: + bool _internal_deprecated() const; + void _internal_set_deprecated(bool value); + public: + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(EnumOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.EnumOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + bool allow_alias_; + bool deprecated_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT EnumValueOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.EnumValueOptions) */ { + public: + inline EnumValueOptions() : EnumValueOptions(nullptr) {} + virtual ~EnumValueOptions(); + + EnumValueOptions(const EnumValueOptions& from); + EnumValueOptions(EnumValueOptions&& from) noexcept + : EnumValueOptions() { + *this = ::std::move(from); + } + + inline EnumValueOptions& operator=(const EnumValueOptions& from) { + CopyFrom(from); + return *this; + } + inline EnumValueOptions& operator=(EnumValueOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const EnumValueOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const EnumValueOptions* internal_default_instance() { + return reinterpret_cast( + &_EnumValueOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 18; + + friend void swap(EnumValueOptions& a, EnumValueOptions& b) { + a.Swap(&b); + } + inline void Swap(EnumValueOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(EnumValueOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline EnumValueOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + EnumValueOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const EnumValueOptions& from); + void MergeFrom(const EnumValueOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(EnumValueOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.EnumValueOptions"; + } + protected: + explicit EnumValueOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + kDeprecatedFieldNumber = 1, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + // optional bool deprecated = 1 [default = false]; + bool has_deprecated() const; + private: + bool _internal_has_deprecated() const; + public: + void clear_deprecated(); + bool deprecated() const; + void set_deprecated(bool value); + private: + bool _internal_deprecated() const; + void _internal_set_deprecated(bool value); + public: + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(EnumValueOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.EnumValueOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + bool deprecated_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT ServiceOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.ServiceOptions) */ { + public: + inline ServiceOptions() : ServiceOptions(nullptr) {} + virtual ~ServiceOptions(); + + ServiceOptions(const ServiceOptions& from); + ServiceOptions(ServiceOptions&& from) noexcept + : ServiceOptions() { + *this = ::std::move(from); + } + + inline ServiceOptions& operator=(const ServiceOptions& from) { + CopyFrom(from); + return *this; + } + inline ServiceOptions& operator=(ServiceOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ServiceOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ServiceOptions* internal_default_instance() { + return reinterpret_cast( + &_ServiceOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 19; + + friend void swap(ServiceOptions& a, ServiceOptions& b) { + a.Swap(&b); + } + inline void Swap(ServiceOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ServiceOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ServiceOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + ServiceOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ServiceOptions& from); + void MergeFrom(const ServiceOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ServiceOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.ServiceOptions"; + } + protected: + explicit ServiceOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + kDeprecatedFieldNumber = 33, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + // optional bool deprecated = 33 [default = false]; + bool has_deprecated() const; + private: + bool _internal_has_deprecated() const; + public: + void clear_deprecated(); + bool deprecated() const; + void set_deprecated(bool value); + private: + bool _internal_deprecated() const; + void _internal_set_deprecated(bool value); + public: + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(ServiceOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.ServiceOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + bool deprecated_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT MethodOptions PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.MethodOptions) */ { + public: + inline MethodOptions() : MethodOptions(nullptr) {} + virtual ~MethodOptions(); + + MethodOptions(const MethodOptions& from); + MethodOptions(MethodOptions&& from) noexcept + : MethodOptions() { + *this = ::std::move(from); + } + + inline MethodOptions& operator=(const MethodOptions& from) { + CopyFrom(from); + return *this; + } + inline MethodOptions& operator=(MethodOptions&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const MethodOptions& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const MethodOptions* internal_default_instance() { + return reinterpret_cast( + &_MethodOptions_default_instance_); + } + static constexpr int kIndexInFileMessages = + 20; + + friend void swap(MethodOptions& a, MethodOptions& b) { + a.Swap(&b); + } + inline void Swap(MethodOptions* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(MethodOptions* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline MethodOptions* New() const final { + return CreateMaybeMessage(nullptr); + } + + MethodOptions* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const MethodOptions& from); + void MergeFrom(const MethodOptions& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(MethodOptions* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.MethodOptions"; + } + protected: + explicit MethodOptions(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef MethodOptions_IdempotencyLevel IdempotencyLevel; + static constexpr IdempotencyLevel IDEMPOTENCY_UNKNOWN = + MethodOptions_IdempotencyLevel_IDEMPOTENCY_UNKNOWN; + static constexpr IdempotencyLevel NO_SIDE_EFFECTS = + MethodOptions_IdempotencyLevel_NO_SIDE_EFFECTS; + static constexpr IdempotencyLevel IDEMPOTENT = + MethodOptions_IdempotencyLevel_IDEMPOTENT; + static inline bool IdempotencyLevel_IsValid(int value) { + return MethodOptions_IdempotencyLevel_IsValid(value); + } + static constexpr IdempotencyLevel IdempotencyLevel_MIN = + MethodOptions_IdempotencyLevel_IdempotencyLevel_MIN; + static constexpr IdempotencyLevel IdempotencyLevel_MAX = + MethodOptions_IdempotencyLevel_IdempotencyLevel_MAX; + static constexpr int IdempotencyLevel_ARRAYSIZE = + MethodOptions_IdempotencyLevel_IdempotencyLevel_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + IdempotencyLevel_descriptor() { + return MethodOptions_IdempotencyLevel_descriptor(); + } + template + static inline const std::string& IdempotencyLevel_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function IdempotencyLevel_Name."); + return MethodOptions_IdempotencyLevel_Name(enum_t_value); + } + static inline bool IdempotencyLevel_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + IdempotencyLevel* value) { + return MethodOptions_IdempotencyLevel_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kUninterpretedOptionFieldNumber = 999, + kDeprecatedFieldNumber = 33, + kIdempotencyLevelFieldNumber = 34, + }; + // repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; + int uninterpreted_option_size() const; + private: + int _internal_uninterpreted_option_size() const; + public: + void clear_uninterpreted_option(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption* mutable_uninterpreted_option(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* + mutable_uninterpreted_option(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& _internal_uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* _internal_add_uninterpreted_option(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption& uninterpreted_option(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption* add_uninterpreted_option(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& + uninterpreted_option() const; + + // optional bool deprecated = 33 [default = false]; + bool has_deprecated() const; + private: + bool _internal_has_deprecated() const; + public: + void clear_deprecated(); + bool deprecated() const; + void set_deprecated(bool value); + private: + bool _internal_deprecated() const; + void _internal_set_deprecated(bool value); + public: + + // optional .google.protobuf.MethodOptions.IdempotencyLevel idempotency_level = 34 [default = IDEMPOTENCY_UNKNOWN]; + bool has_idempotency_level() const; + private: + bool _internal_has_idempotency_level() const; + public: + void clear_idempotency_level(); + PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel idempotency_level() const; + void set_idempotency_level(PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel value); + private: + PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel _internal_idempotency_level() const; + void _internal_set_idempotency_level(PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel value); + public: + + GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(MethodOptions) + // @@protoc_insertion_point(class_scope:google.protobuf.MethodOptions) + private: + class _Internal; + + ::PROTOBUF_NAMESPACE_ID::internal::ExtensionSet _extensions_; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption > uninterpreted_option_; + bool deprecated_; + int idempotency_level_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT UninterpretedOption_NamePart PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.UninterpretedOption.NamePart) */ { + public: + inline UninterpretedOption_NamePart() : UninterpretedOption_NamePart(nullptr) {} + virtual ~UninterpretedOption_NamePart(); + + UninterpretedOption_NamePart(const UninterpretedOption_NamePart& from); + UninterpretedOption_NamePart(UninterpretedOption_NamePart&& from) noexcept + : UninterpretedOption_NamePart() { + *this = ::std::move(from); + } + + inline UninterpretedOption_NamePart& operator=(const UninterpretedOption_NamePart& from) { + CopyFrom(from); + return *this; + } + inline UninterpretedOption_NamePart& operator=(UninterpretedOption_NamePart&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const UninterpretedOption_NamePart& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const UninterpretedOption_NamePart* internal_default_instance() { + return reinterpret_cast( + &_UninterpretedOption_NamePart_default_instance_); + } + static constexpr int kIndexInFileMessages = + 21; + + friend void swap(UninterpretedOption_NamePart& a, UninterpretedOption_NamePart& b) { + a.Swap(&b); + } + inline void Swap(UninterpretedOption_NamePart* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(UninterpretedOption_NamePart* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline UninterpretedOption_NamePart* New() const final { + return CreateMaybeMessage(nullptr); + } + + UninterpretedOption_NamePart* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const UninterpretedOption_NamePart& from); + void MergeFrom(const UninterpretedOption_NamePart& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(UninterpretedOption_NamePart* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.UninterpretedOption.NamePart"; + } + protected: + explicit UninterpretedOption_NamePart(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNamePartFieldNumber = 1, + kIsExtensionFieldNumber = 2, + }; + // required string name_part = 1; + bool has_name_part() const; + private: + bool _internal_has_name_part() const; + public: + void clear_name_part(); + const std::string& name_part() const; + void set_name_part(const std::string& value); + void set_name_part(std::string&& value); + void set_name_part(const char* value); + void set_name_part(const char* value, size_t size); + std::string* mutable_name_part(); + std::string* release_name_part(); + void set_allocated_name_part(std::string* name_part); + private: + const std::string& _internal_name_part() const; + void _internal_set_name_part(const std::string& value); + std::string* _internal_mutable_name_part(); + public: + + // required bool is_extension = 2; + bool has_is_extension() const; + private: + bool _internal_has_is_extension() const; + public: + void clear_is_extension(); + bool is_extension() const; + void set_is_extension(bool value); + private: + bool _internal_is_extension() const; + void _internal_set_is_extension(bool value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.UninterpretedOption.NamePart) + private: + class _Internal; + + // helper for ByteSizeLong() + size_t RequiredFieldsByteSizeFallback() const; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_part_; + bool is_extension_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT UninterpretedOption PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.UninterpretedOption) */ { + public: + inline UninterpretedOption() : UninterpretedOption(nullptr) {} + virtual ~UninterpretedOption(); + + UninterpretedOption(const UninterpretedOption& from); + UninterpretedOption(UninterpretedOption&& from) noexcept + : UninterpretedOption() { + *this = ::std::move(from); + } + + inline UninterpretedOption& operator=(const UninterpretedOption& from) { + CopyFrom(from); + return *this; + } + inline UninterpretedOption& operator=(UninterpretedOption&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const UninterpretedOption& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const UninterpretedOption* internal_default_instance() { + return reinterpret_cast( + &_UninterpretedOption_default_instance_); + } + static constexpr int kIndexInFileMessages = + 22; + + friend void swap(UninterpretedOption& a, UninterpretedOption& b) { + a.Swap(&b); + } + inline void Swap(UninterpretedOption* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(UninterpretedOption* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline UninterpretedOption* New() const final { + return CreateMaybeMessage(nullptr); + } + + UninterpretedOption* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const UninterpretedOption& from); + void MergeFrom(const UninterpretedOption& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(UninterpretedOption* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.UninterpretedOption"; + } + protected: + explicit UninterpretedOption(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef UninterpretedOption_NamePart NamePart; + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 2, + kIdentifierValueFieldNumber = 3, + kStringValueFieldNumber = 7, + kAggregateValueFieldNumber = 8, + kPositiveIntValueFieldNumber = 4, + kNegativeIntValueFieldNumber = 5, + kDoubleValueFieldNumber = 6, + }; + // repeated .google.protobuf.UninterpretedOption.NamePart name = 2; + int name_size() const; + private: + int _internal_name_size() const; + public: + void clear_name(); + PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart* mutable_name(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart >* + mutable_name(); + private: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart& _internal_name(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart* _internal_add_name(); + public: + const PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart& name(int index) const; + PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart* add_name(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart >& + name() const; + + // optional string identifier_value = 3; + bool has_identifier_value() const; + private: + bool _internal_has_identifier_value() const; + public: + void clear_identifier_value(); + const std::string& identifier_value() const; + void set_identifier_value(const std::string& value); + void set_identifier_value(std::string&& value); + void set_identifier_value(const char* value); + void set_identifier_value(const char* value, size_t size); + std::string* mutable_identifier_value(); + std::string* release_identifier_value(); + void set_allocated_identifier_value(std::string* identifier_value); + private: + const std::string& _internal_identifier_value() const; + void _internal_set_identifier_value(const std::string& value); + std::string* _internal_mutable_identifier_value(); + public: + + // optional bytes string_value = 7; + bool has_string_value() const; + private: + bool _internal_has_string_value() const; + public: + void clear_string_value(); + const std::string& string_value() const; + void set_string_value(const std::string& value); + void set_string_value(std::string&& value); + void set_string_value(const char* value); + void set_string_value(const void* value, size_t size); + std::string* mutable_string_value(); + std::string* release_string_value(); + void set_allocated_string_value(std::string* string_value); + private: + const std::string& _internal_string_value() const; + void _internal_set_string_value(const std::string& value); + std::string* _internal_mutable_string_value(); + public: + + // optional string aggregate_value = 8; + bool has_aggregate_value() const; + private: + bool _internal_has_aggregate_value() const; + public: + void clear_aggregate_value(); + const std::string& aggregate_value() const; + void set_aggregate_value(const std::string& value); + void set_aggregate_value(std::string&& value); + void set_aggregate_value(const char* value); + void set_aggregate_value(const char* value, size_t size); + std::string* mutable_aggregate_value(); + std::string* release_aggregate_value(); + void set_allocated_aggregate_value(std::string* aggregate_value); + private: + const std::string& _internal_aggregate_value() const; + void _internal_set_aggregate_value(const std::string& value); + std::string* _internal_mutable_aggregate_value(); + public: + + // optional uint64 positive_int_value = 4; + bool has_positive_int_value() const; + private: + bool _internal_has_positive_int_value() const; + public: + void clear_positive_int_value(); + ::PROTOBUF_NAMESPACE_ID::uint64 positive_int_value() const; + void set_positive_int_value(::PROTOBUF_NAMESPACE_ID::uint64 value); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_positive_int_value() const; + void _internal_set_positive_int_value(::PROTOBUF_NAMESPACE_ID::uint64 value); + public: + + // optional int64 negative_int_value = 5; + bool has_negative_int_value() const; + private: + bool _internal_has_negative_int_value() const; + public: + void clear_negative_int_value(); + ::PROTOBUF_NAMESPACE_ID::int64 negative_int_value() const; + void set_negative_int_value(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_negative_int_value() const; + void _internal_set_negative_int_value(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional double double_value = 6; + bool has_double_value() const; + private: + bool _internal_has_double_value() const; + public: + void clear_double_value(); + double double_value() const; + void set_double_value(double value); + private: + double _internal_double_value() const; + void _internal_set_double_value(double value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.UninterpretedOption) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart > name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr identifier_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr string_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr aggregate_value_; + ::PROTOBUF_NAMESPACE_ID::uint64 positive_int_value_; + ::PROTOBUF_NAMESPACE_ID::int64 negative_int_value_; + double double_value_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT SourceCodeInfo_Location PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.SourceCodeInfo.Location) */ { + public: + inline SourceCodeInfo_Location() : SourceCodeInfo_Location(nullptr) {} + virtual ~SourceCodeInfo_Location(); + + SourceCodeInfo_Location(const SourceCodeInfo_Location& from); + SourceCodeInfo_Location(SourceCodeInfo_Location&& from) noexcept + : SourceCodeInfo_Location() { + *this = ::std::move(from); + } + + inline SourceCodeInfo_Location& operator=(const SourceCodeInfo_Location& from) { + CopyFrom(from); + return *this; + } + inline SourceCodeInfo_Location& operator=(SourceCodeInfo_Location&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SourceCodeInfo_Location& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SourceCodeInfo_Location* internal_default_instance() { + return reinterpret_cast( + &_SourceCodeInfo_Location_default_instance_); + } + static constexpr int kIndexInFileMessages = + 23; + + friend void swap(SourceCodeInfo_Location& a, SourceCodeInfo_Location& b) { + a.Swap(&b); + } + inline void Swap(SourceCodeInfo_Location* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SourceCodeInfo_Location* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SourceCodeInfo_Location* New() const final { + return CreateMaybeMessage(nullptr); + } + + SourceCodeInfo_Location* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SourceCodeInfo_Location& from); + void MergeFrom(const SourceCodeInfo_Location& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SourceCodeInfo_Location* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.SourceCodeInfo.Location"; + } + protected: + explicit SourceCodeInfo_Location(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kPathFieldNumber = 1, + kSpanFieldNumber = 2, + kLeadingDetachedCommentsFieldNumber = 6, + kLeadingCommentsFieldNumber = 3, + kTrailingCommentsFieldNumber = 4, + }; + // repeated int32 path = 1 [packed = true]; + int path_size() const; + private: + int _internal_path_size() const; + public: + void clear_path(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_path(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_path() const; + void _internal_add_path(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_path(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 path(int index) const; + void set_path(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_path(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + path() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_path(); + + // repeated int32 span = 2 [packed = true]; + int span_size() const; + private: + int _internal_span_size() const; + public: + void clear_span(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_span(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_span() const; + void _internal_add_span(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_span(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 span(int index) const; + void set_span(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_span(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + span() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_span(); + + // repeated string leading_detached_comments = 6; + int leading_detached_comments_size() const; + private: + int _internal_leading_detached_comments_size() const; + public: + void clear_leading_detached_comments(); + const std::string& leading_detached_comments(int index) const; + std::string* mutable_leading_detached_comments(int index); + void set_leading_detached_comments(int index, const std::string& value); + void set_leading_detached_comments(int index, std::string&& value); + void set_leading_detached_comments(int index, const char* value); + void set_leading_detached_comments(int index, const char* value, size_t size); + std::string* add_leading_detached_comments(); + void add_leading_detached_comments(const std::string& value); + void add_leading_detached_comments(std::string&& value); + void add_leading_detached_comments(const char* value); + void add_leading_detached_comments(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& leading_detached_comments() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_leading_detached_comments(); + private: + const std::string& _internal_leading_detached_comments(int index) const; + std::string* _internal_add_leading_detached_comments(); + public: + + // optional string leading_comments = 3; + bool has_leading_comments() const; + private: + bool _internal_has_leading_comments() const; + public: + void clear_leading_comments(); + const std::string& leading_comments() const; + void set_leading_comments(const std::string& value); + void set_leading_comments(std::string&& value); + void set_leading_comments(const char* value); + void set_leading_comments(const char* value, size_t size); + std::string* mutable_leading_comments(); + std::string* release_leading_comments(); + void set_allocated_leading_comments(std::string* leading_comments); + private: + const std::string& _internal_leading_comments() const; + void _internal_set_leading_comments(const std::string& value); + std::string* _internal_mutable_leading_comments(); + public: + + // optional string trailing_comments = 4; + bool has_trailing_comments() const; + private: + bool _internal_has_trailing_comments() const; + public: + void clear_trailing_comments(); + const std::string& trailing_comments() const; + void set_trailing_comments(const std::string& value); + void set_trailing_comments(std::string&& value); + void set_trailing_comments(const char* value); + void set_trailing_comments(const char* value, size_t size); + std::string* mutable_trailing_comments(); + std::string* release_trailing_comments(); + void set_allocated_trailing_comments(std::string* trailing_comments); + private: + const std::string& _internal_trailing_comments() const; + void _internal_set_trailing_comments(const std::string& value); + std::string* _internal_mutable_trailing_comments(); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.SourceCodeInfo.Location) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > path_; + mutable std::atomic _path_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > span_; + mutable std::atomic _span_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField leading_detached_comments_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr leading_comments_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr trailing_comments_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT SourceCodeInfo PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.SourceCodeInfo) */ { + public: + inline SourceCodeInfo() : SourceCodeInfo(nullptr) {} + virtual ~SourceCodeInfo(); + + SourceCodeInfo(const SourceCodeInfo& from); + SourceCodeInfo(SourceCodeInfo&& from) noexcept + : SourceCodeInfo() { + *this = ::std::move(from); + } + + inline SourceCodeInfo& operator=(const SourceCodeInfo& from) { + CopyFrom(from); + return *this; + } + inline SourceCodeInfo& operator=(SourceCodeInfo&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SourceCodeInfo& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SourceCodeInfo* internal_default_instance() { + return reinterpret_cast( + &_SourceCodeInfo_default_instance_); + } + static constexpr int kIndexInFileMessages = + 24; + + friend void swap(SourceCodeInfo& a, SourceCodeInfo& b) { + a.Swap(&b); + } + inline void Swap(SourceCodeInfo* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SourceCodeInfo* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SourceCodeInfo* New() const final { + return CreateMaybeMessage(nullptr); + } + + SourceCodeInfo* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SourceCodeInfo& from); + void MergeFrom(const SourceCodeInfo& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SourceCodeInfo* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.SourceCodeInfo"; + } + protected: + explicit SourceCodeInfo(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef SourceCodeInfo_Location Location; + + // accessors ------------------------------------------------------- + + enum : int { + kLocationFieldNumber = 1, + }; + // repeated .google.protobuf.SourceCodeInfo.Location location = 1; + int location_size() const; + private: + int _internal_location_size() const; + public: + void clear_location(); + PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location* mutable_location(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location >* + mutable_location(); + private: + const PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location& _internal_location(int index) const; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location* _internal_add_location(); + public: + const PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location& location(int index) const; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location* add_location(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location >& + location() const; + + // @@protoc_insertion_point(class_scope:google.protobuf.SourceCodeInfo) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location > location_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT GeneratedCodeInfo_Annotation PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.GeneratedCodeInfo.Annotation) */ { + public: + inline GeneratedCodeInfo_Annotation() : GeneratedCodeInfo_Annotation(nullptr) {} + virtual ~GeneratedCodeInfo_Annotation(); + + GeneratedCodeInfo_Annotation(const GeneratedCodeInfo_Annotation& from); + GeneratedCodeInfo_Annotation(GeneratedCodeInfo_Annotation&& from) noexcept + : GeneratedCodeInfo_Annotation() { + *this = ::std::move(from); + } + + inline GeneratedCodeInfo_Annotation& operator=(const GeneratedCodeInfo_Annotation& from) { + CopyFrom(from); + return *this; + } + inline GeneratedCodeInfo_Annotation& operator=(GeneratedCodeInfo_Annotation&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const GeneratedCodeInfo_Annotation& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const GeneratedCodeInfo_Annotation* internal_default_instance() { + return reinterpret_cast( + &_GeneratedCodeInfo_Annotation_default_instance_); + } + static constexpr int kIndexInFileMessages = + 25; + + friend void swap(GeneratedCodeInfo_Annotation& a, GeneratedCodeInfo_Annotation& b) { + a.Swap(&b); + } + inline void Swap(GeneratedCodeInfo_Annotation* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(GeneratedCodeInfo_Annotation* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline GeneratedCodeInfo_Annotation* New() const final { + return CreateMaybeMessage(nullptr); + } + + GeneratedCodeInfo_Annotation* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const GeneratedCodeInfo_Annotation& from); + void MergeFrom(const GeneratedCodeInfo_Annotation& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(GeneratedCodeInfo_Annotation* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.GeneratedCodeInfo.Annotation"; + } + protected: + explicit GeneratedCodeInfo_Annotation(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kPathFieldNumber = 1, + kSourceFileFieldNumber = 2, + kBeginFieldNumber = 3, + kEndFieldNumber = 4, + }; + // repeated int32 path = 1 [packed = true]; + int path_size() const; + private: + int _internal_path_size() const; + public: + void clear_path(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_path(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_path() const; + void _internal_add_path(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_path(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 path(int index) const; + void set_path(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_path(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + path() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_path(); + + // optional string source_file = 2; + bool has_source_file() const; + private: + bool _internal_has_source_file() const; + public: + void clear_source_file(); + const std::string& source_file() const; + void set_source_file(const std::string& value); + void set_source_file(std::string&& value); + void set_source_file(const char* value); + void set_source_file(const char* value, size_t size); + std::string* mutable_source_file(); + std::string* release_source_file(); + void set_allocated_source_file(std::string* source_file); + private: + const std::string& _internal_source_file() const; + void _internal_set_source_file(const std::string& value); + std::string* _internal_mutable_source_file(); + public: + + // optional int32 begin = 3; + bool has_begin() const; + private: + bool _internal_has_begin() const; + public: + void clear_begin(); + ::PROTOBUF_NAMESPACE_ID::int32 begin() const; + void set_begin(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_begin() const; + void _internal_set_begin(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 end = 4; + bool has_end() const; + private: + bool _internal_has_end() const; + public: + void clear_end(); + ::PROTOBUF_NAMESPACE_ID::int32 end() const; + void set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_end() const; + void _internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.GeneratedCodeInfo.Annotation) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > path_; + mutable std::atomic _path_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr source_file_; + ::PROTOBUF_NAMESPACE_ID::int32 begin_; + ::PROTOBUF_NAMESPACE_ID::int32 end_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT GeneratedCodeInfo PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.GeneratedCodeInfo) */ { + public: + inline GeneratedCodeInfo() : GeneratedCodeInfo(nullptr) {} + virtual ~GeneratedCodeInfo(); + + GeneratedCodeInfo(const GeneratedCodeInfo& from); + GeneratedCodeInfo(GeneratedCodeInfo&& from) noexcept + : GeneratedCodeInfo() { + *this = ::std::move(from); + } + + inline GeneratedCodeInfo& operator=(const GeneratedCodeInfo& from) { + CopyFrom(from); + return *this; + } + inline GeneratedCodeInfo& operator=(GeneratedCodeInfo&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const GeneratedCodeInfo& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const GeneratedCodeInfo* internal_default_instance() { + return reinterpret_cast( + &_GeneratedCodeInfo_default_instance_); + } + static constexpr int kIndexInFileMessages = + 26; + + friend void swap(GeneratedCodeInfo& a, GeneratedCodeInfo& b) { + a.Swap(&b); + } + inline void Swap(GeneratedCodeInfo* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(GeneratedCodeInfo* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline GeneratedCodeInfo* New() const final { + return CreateMaybeMessage(nullptr); + } + + GeneratedCodeInfo* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const GeneratedCodeInfo& from); + void MergeFrom(const GeneratedCodeInfo& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(GeneratedCodeInfo* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.GeneratedCodeInfo"; + } + protected: + explicit GeneratedCodeInfo(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto); + return ::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef GeneratedCodeInfo_Annotation Annotation; + + // accessors ------------------------------------------------------- + + enum : int { + kAnnotationFieldNumber = 1, + }; + // repeated .google.protobuf.GeneratedCodeInfo.Annotation annotation = 1; + int annotation_size() const; + private: + int _internal_annotation_size() const; + public: + void clear_annotation(); + PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation* mutable_annotation(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation >* + mutable_annotation(); + private: + const PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation& _internal_annotation(int index) const; + PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation* _internal_add_annotation(); + public: + const PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation& annotation(int index) const; + PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation* add_annotation(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation >& + annotation() const; + + // @@protoc_insertion_point(class_scope:google.protobuf.GeneratedCodeInfo) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation > annotation_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fdescriptor_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// FileDescriptorSet + +// repeated .google.protobuf.FileDescriptorProto file = 1; +inline int FileDescriptorSet::_internal_file_size() const { + return file_.size(); +} +inline int FileDescriptorSet::file_size() const { + return _internal_file_size(); +} +inline void FileDescriptorSet::clear_file() { + file_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* FileDescriptorSet::mutable_file(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorSet.file) + return file_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >* +FileDescriptorSet::mutable_file() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorSet.file) + return &file_; +} +inline const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& FileDescriptorSet::_internal_file(int index) const { + return file_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& FileDescriptorSet::file(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorSet.file) + return _internal_file(index); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* FileDescriptorSet::_internal_add_file() { + return file_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* FileDescriptorSet::add_file() { + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorSet.file) + return _internal_add_file(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >& +FileDescriptorSet::file() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorSet.file) + return file_; +} + +// ------------------------------------------------------------------- + +// FileDescriptorProto + +// optional string name = 1; +inline bool FileDescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool FileDescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void FileDescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& FileDescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.name) + return _internal_name(); +} +inline void FileDescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileDescriptorProto.name) +} +inline std::string* FileDescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& FileDescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void FileDescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileDescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileDescriptorProto.name) +} +inline void FileDescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileDescriptorProto.name) +} +inline void FileDescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileDescriptorProto.name) +} +inline std::string* FileDescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileDescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.FileDescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileDescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileDescriptorProto.name) +} + +// optional string package = 2; +inline bool FileDescriptorProto::_internal_has_package() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool FileDescriptorProto::has_package() const { + return _internal_has_package(); +} +inline void FileDescriptorProto::clear_package() { + package_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& FileDescriptorProto::package() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.package) + return _internal_package(); +} +inline void FileDescriptorProto::set_package(const std::string& value) { + _internal_set_package(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileDescriptorProto.package) +} +inline std::string* FileDescriptorProto::mutable_package() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.package) + return _internal_mutable_package(); +} +inline const std::string& FileDescriptorProto::_internal_package() const { + return package_.Get(); +} +inline void FileDescriptorProto::_internal_set_package(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileDescriptorProto::set_package(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + package_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileDescriptorProto.package) +} +inline void FileDescriptorProto::set_package(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileDescriptorProto.package) +} +inline void FileDescriptorProto::set_package(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileDescriptorProto.package) +} +inline std::string* FileDescriptorProto::_internal_mutable_package() { + _has_bits_[0] |= 0x00000002u; + return package_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileDescriptorProto::release_package() { + // @@protoc_insertion_point(field_release:google.protobuf.FileDescriptorProto.package) + if (!_internal_has_package()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return package_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileDescriptorProto::set_allocated_package(std::string* package) { + if (package != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + package_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), package, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileDescriptorProto.package) +} + +// repeated string dependency = 3; +inline int FileDescriptorProto::_internal_dependency_size() const { + return dependency_.size(); +} +inline int FileDescriptorProto::dependency_size() const { + return _internal_dependency_size(); +} +inline void FileDescriptorProto::clear_dependency() { + dependency_.Clear(); +} +inline std::string* FileDescriptorProto::add_dependency() { + // @@protoc_insertion_point(field_add_mutable:google.protobuf.FileDescriptorProto.dependency) + return _internal_add_dependency(); +} +inline const std::string& FileDescriptorProto::_internal_dependency(int index) const { + return dependency_.Get(index); +} +inline const std::string& FileDescriptorProto::dependency(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.dependency) + return _internal_dependency(index); +} +inline std::string* FileDescriptorProto::mutable_dependency(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.dependency) + return dependency_.Mutable(index); +} +inline void FileDescriptorProto::set_dependency(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:google.protobuf.FileDescriptorProto.dependency) + dependency_.Mutable(index)->assign(value); +} +inline void FileDescriptorProto::set_dependency(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.FileDescriptorProto.dependency) + dependency_.Mutable(index)->assign(std::move(value)); +} +inline void FileDescriptorProto::set_dependency(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + dependency_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileDescriptorProto.dependency) +} +inline void FileDescriptorProto::set_dependency(int index, const char* value, size_t size) { + dependency_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileDescriptorProto.dependency) +} +inline std::string* FileDescriptorProto::_internal_add_dependency() { + return dependency_.Add(); +} +inline void FileDescriptorProto::add_dependency(const std::string& value) { + dependency_.Add()->assign(value); + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.dependency) +} +inline void FileDescriptorProto::add_dependency(std::string&& value) { + dependency_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.dependency) +} +inline void FileDescriptorProto::add_dependency(const char* value) { + GOOGLE_DCHECK(value != nullptr); + dependency_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:google.protobuf.FileDescriptorProto.dependency) +} +inline void FileDescriptorProto::add_dependency(const char* value, size_t size) { + dependency_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:google.protobuf.FileDescriptorProto.dependency) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +FileDescriptorProto::dependency() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorProto.dependency) + return dependency_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +FileDescriptorProto::mutable_dependency() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorProto.dependency) + return &dependency_; +} + +// repeated int32 public_dependency = 10; +inline int FileDescriptorProto::_internal_public_dependency_size() const { + return public_dependency_.size(); +} +inline int FileDescriptorProto::public_dependency_size() const { + return _internal_public_dependency_size(); +} +inline void FileDescriptorProto::clear_public_dependency() { + public_dependency_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FileDescriptorProto::_internal_public_dependency(int index) const { + return public_dependency_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FileDescriptorProto::public_dependency(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.public_dependency) + return _internal_public_dependency(index); +} +inline void FileDescriptorProto::set_public_dependency(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + public_dependency_.Set(index, value); + // @@protoc_insertion_point(field_set:google.protobuf.FileDescriptorProto.public_dependency) +} +inline void FileDescriptorProto::_internal_add_public_dependency(::PROTOBUF_NAMESPACE_ID::int32 value) { + public_dependency_.Add(value); +} +inline void FileDescriptorProto::add_public_dependency(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_public_dependency(value); + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.public_dependency) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +FileDescriptorProto::_internal_public_dependency() const { + return public_dependency_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +FileDescriptorProto::public_dependency() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorProto.public_dependency) + return _internal_public_dependency(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +FileDescriptorProto::_internal_mutable_public_dependency() { + return &public_dependency_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +FileDescriptorProto::mutable_public_dependency() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorProto.public_dependency) + return _internal_mutable_public_dependency(); +} + +// repeated int32 weak_dependency = 11; +inline int FileDescriptorProto::_internal_weak_dependency_size() const { + return weak_dependency_.size(); +} +inline int FileDescriptorProto::weak_dependency_size() const { + return _internal_weak_dependency_size(); +} +inline void FileDescriptorProto::clear_weak_dependency() { + weak_dependency_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FileDescriptorProto::_internal_weak_dependency(int index) const { + return weak_dependency_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FileDescriptorProto::weak_dependency(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.weak_dependency) + return _internal_weak_dependency(index); +} +inline void FileDescriptorProto::set_weak_dependency(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + weak_dependency_.Set(index, value); + // @@protoc_insertion_point(field_set:google.protobuf.FileDescriptorProto.weak_dependency) +} +inline void FileDescriptorProto::_internal_add_weak_dependency(::PROTOBUF_NAMESPACE_ID::int32 value) { + weak_dependency_.Add(value); +} +inline void FileDescriptorProto::add_weak_dependency(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_weak_dependency(value); + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.weak_dependency) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +FileDescriptorProto::_internal_weak_dependency() const { + return weak_dependency_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +FileDescriptorProto::weak_dependency() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorProto.weak_dependency) + return _internal_weak_dependency(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +FileDescriptorProto::_internal_mutable_weak_dependency() { + return &weak_dependency_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +FileDescriptorProto::mutable_weak_dependency() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorProto.weak_dependency) + return _internal_mutable_weak_dependency(); +} + +// repeated .google.protobuf.DescriptorProto message_type = 4; +inline int FileDescriptorProto::_internal_message_type_size() const { + return message_type_.size(); +} +inline int FileDescriptorProto::message_type_size() const { + return _internal_message_type_size(); +} +inline void FileDescriptorProto::clear_message_type() { + message_type_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto* FileDescriptorProto::mutable_message_type(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.message_type) + return message_type_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >* +FileDescriptorProto::mutable_message_type() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorProto.message_type) + return &message_type_; +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto& FileDescriptorProto::_internal_message_type(int index) const { + return message_type_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto& FileDescriptorProto::message_type(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.message_type) + return _internal_message_type(index); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto* FileDescriptorProto::_internal_add_message_type() { + return message_type_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto* FileDescriptorProto::add_message_type() { + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.message_type) + return _internal_add_message_type(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >& +FileDescriptorProto::message_type() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorProto.message_type) + return message_type_; +} + +// repeated .google.protobuf.EnumDescriptorProto enum_type = 5; +inline int FileDescriptorProto::_internal_enum_type_size() const { + return enum_type_.size(); +} +inline int FileDescriptorProto::enum_type_size() const { + return _internal_enum_type_size(); +} +inline void FileDescriptorProto::clear_enum_type() { + enum_type_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* FileDescriptorProto::mutable_enum_type(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.enum_type) + return enum_type_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >* +FileDescriptorProto::mutable_enum_type() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorProto.enum_type) + return &enum_type_; +} +inline const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& FileDescriptorProto::_internal_enum_type(int index) const { + return enum_type_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& FileDescriptorProto::enum_type(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.enum_type) + return _internal_enum_type(index); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* FileDescriptorProto::_internal_add_enum_type() { + return enum_type_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* FileDescriptorProto::add_enum_type() { + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.enum_type) + return _internal_add_enum_type(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >& +FileDescriptorProto::enum_type() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorProto.enum_type) + return enum_type_; +} + +// repeated .google.protobuf.ServiceDescriptorProto service = 6; +inline int FileDescriptorProto::_internal_service_size() const { + return service_.size(); +} +inline int FileDescriptorProto::service_size() const { + return _internal_service_size(); +} +inline void FileDescriptorProto::clear_service() { + service_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto* FileDescriptorProto::mutable_service(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.service) + return service_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto >* +FileDescriptorProto::mutable_service() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorProto.service) + return &service_; +} +inline const PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto& FileDescriptorProto::_internal_service(int index) const { + return service_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto& FileDescriptorProto::service(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.service) + return _internal_service(index); +} +inline PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto* FileDescriptorProto::_internal_add_service() { + return service_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto* FileDescriptorProto::add_service() { + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.service) + return _internal_add_service(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::ServiceDescriptorProto >& +FileDescriptorProto::service() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorProto.service) + return service_; +} + +// repeated .google.protobuf.FieldDescriptorProto extension = 7; +inline int FileDescriptorProto::_internal_extension_size() const { + return extension_.size(); +} +inline int FileDescriptorProto::extension_size() const { + return _internal_extension_size(); +} +inline void FileDescriptorProto::clear_extension() { + extension_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* FileDescriptorProto::mutable_extension(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.extension) + return extension_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >* +FileDescriptorProto::mutable_extension() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileDescriptorProto.extension) + return &extension_; +} +inline const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& FileDescriptorProto::_internal_extension(int index) const { + return extension_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& FileDescriptorProto::extension(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.extension) + return _internal_extension(index); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* FileDescriptorProto::_internal_add_extension() { + return extension_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* FileDescriptorProto::add_extension() { + // @@protoc_insertion_point(field_add:google.protobuf.FileDescriptorProto.extension) + return _internal_add_extension(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >& +FileDescriptorProto::extension() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileDescriptorProto.extension) + return extension_; +} + +// optional .google.protobuf.FileOptions options = 8; +inline bool FileDescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool FileDescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void FileDescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000008u; +} +inline const PROTOBUF_NAMESPACE_ID::FileOptions& FileDescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::FileOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_FileOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::FileOptions& FileDescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.options) + return _internal_options(); +} +inline void FileDescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::FileOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.FileDescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::FileOptions* FileDescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000008u; + PROTOBUF_NAMESPACE_ID::FileOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::FileOptions* FileDescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.FileDescriptorProto.options) + _has_bits_[0] &= ~0x00000008u; + PROTOBUF_NAMESPACE_ID::FileOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::FileOptions* FileDescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000008u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::FileOptions* FileDescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.options) + return _internal_mutable_options(); +} +inline void FileDescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::FileOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileDescriptorProto.options) +} + +// optional .google.protobuf.SourceCodeInfo source_code_info = 9; +inline bool FileDescriptorProto::_internal_has_source_code_info() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + PROTOBUF_ASSUME(!value || source_code_info_ != nullptr); + return value; +} +inline bool FileDescriptorProto::has_source_code_info() const { + return _internal_has_source_code_info(); +} +inline void FileDescriptorProto::clear_source_code_info() { + if (source_code_info_ != nullptr) source_code_info_->Clear(); + _has_bits_[0] &= ~0x00000010u; +} +inline const PROTOBUF_NAMESPACE_ID::SourceCodeInfo& FileDescriptorProto::_internal_source_code_info() const { + const PROTOBUF_NAMESPACE_ID::SourceCodeInfo* p = source_code_info_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_SourceCodeInfo_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::SourceCodeInfo& FileDescriptorProto::source_code_info() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.source_code_info) + return _internal_source_code_info(); +} +inline void FileDescriptorProto::unsafe_arena_set_allocated_source_code_info( + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* source_code_info) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(source_code_info_); + } + source_code_info_ = source_code_info; + if (source_code_info) { + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.FileDescriptorProto.source_code_info) +} +inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo* FileDescriptorProto::release_source_code_info() { + _has_bits_[0] &= ~0x00000010u; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* temp = source_code_info_; + source_code_info_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo* FileDescriptorProto::unsafe_arena_release_source_code_info() { + // @@protoc_insertion_point(field_release:google.protobuf.FileDescriptorProto.source_code_info) + _has_bits_[0] &= ~0x00000010u; + PROTOBUF_NAMESPACE_ID::SourceCodeInfo* temp = source_code_info_; + source_code_info_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo* FileDescriptorProto::_internal_mutable_source_code_info() { + _has_bits_[0] |= 0x00000010u; + if (source_code_info_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + source_code_info_ = p; + } + return source_code_info_; +} +inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo* FileDescriptorProto::mutable_source_code_info() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.source_code_info) + return _internal_mutable_source_code_info(); +} +inline void FileDescriptorProto::set_allocated_source_code_info(PROTOBUF_NAMESPACE_ID::SourceCodeInfo* source_code_info) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete source_code_info_; + } + if (source_code_info) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(source_code_info); + if (message_arena != submessage_arena) { + source_code_info = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, source_code_info, submessage_arena); + } + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + source_code_info_ = source_code_info; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileDescriptorProto.source_code_info) +} + +// optional string syntax = 12; +inline bool FileDescriptorProto::_internal_has_syntax() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool FileDescriptorProto::has_syntax() const { + return _internal_has_syntax(); +} +inline void FileDescriptorProto::clear_syntax() { + syntax_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& FileDescriptorProto::syntax() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileDescriptorProto.syntax) + return _internal_syntax(); +} +inline void FileDescriptorProto::set_syntax(const std::string& value) { + _internal_set_syntax(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileDescriptorProto.syntax) +} +inline std::string* FileDescriptorProto::mutable_syntax() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileDescriptorProto.syntax) + return _internal_mutable_syntax(); +} +inline const std::string& FileDescriptorProto::_internal_syntax() const { + return syntax_.Get(); +} +inline void FileDescriptorProto::_internal_set_syntax(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + syntax_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileDescriptorProto::set_syntax(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + syntax_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileDescriptorProto.syntax) +} +inline void FileDescriptorProto::set_syntax(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + syntax_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileDescriptorProto.syntax) +} +inline void FileDescriptorProto::set_syntax(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + syntax_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileDescriptorProto.syntax) +} +inline std::string* FileDescriptorProto::_internal_mutable_syntax() { + _has_bits_[0] |= 0x00000004u; + return syntax_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileDescriptorProto::release_syntax() { + // @@protoc_insertion_point(field_release:google.protobuf.FileDescriptorProto.syntax) + if (!_internal_has_syntax()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return syntax_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileDescriptorProto::set_allocated_syntax(std::string* syntax) { + if (syntax != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + syntax_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), syntax, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileDescriptorProto.syntax) +} + +// ------------------------------------------------------------------- + +// DescriptorProto_ExtensionRange + +// optional int32 start = 1; +inline bool DescriptorProto_ExtensionRange::_internal_has_start() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool DescriptorProto_ExtensionRange::has_start() const { + return _internal_has_start(); +} +inline void DescriptorProto_ExtensionRange::clear_start() { + start_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ExtensionRange::_internal_start() const { + return start_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ExtensionRange::start() const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.ExtensionRange.start) + return _internal_start(); +} +inline void DescriptorProto_ExtensionRange::_internal_set_start(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + start_ = value; +} +inline void DescriptorProto_ExtensionRange::set_start(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_start(value); + // @@protoc_insertion_point(field_set:google.protobuf.DescriptorProto.ExtensionRange.start) +} + +// optional int32 end = 2; +inline bool DescriptorProto_ExtensionRange::_internal_has_end() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool DescriptorProto_ExtensionRange::has_end() const { + return _internal_has_end(); +} +inline void DescriptorProto_ExtensionRange::clear_end() { + end_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ExtensionRange::_internal_end() const { + return end_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ExtensionRange::end() const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.ExtensionRange.end) + return _internal_end(); +} +inline void DescriptorProto_ExtensionRange::_internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + end_ = value; +} +inline void DescriptorProto_ExtensionRange::set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_end(value); + // @@protoc_insertion_point(field_set:google.protobuf.DescriptorProto.ExtensionRange.end) +} + +// optional .google.protobuf.ExtensionRangeOptions options = 3; +inline bool DescriptorProto_ExtensionRange::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool DescriptorProto_ExtensionRange::has_options() const { + return _internal_has_options(); +} +inline void DescriptorProto_ExtensionRange::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions& DescriptorProto_ExtensionRange::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_ExtensionRangeOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions& DescriptorProto_ExtensionRange::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.ExtensionRange.options) + return _internal_options(); +} +inline void DescriptorProto_ExtensionRange::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.DescriptorProto.ExtensionRange.options) +} +inline PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* DescriptorProto_ExtensionRange::release_options() { + _has_bits_[0] &= ~0x00000001u; + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* DescriptorProto_ExtensionRange::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.DescriptorProto.ExtensionRange.options) + _has_bits_[0] &= ~0x00000001u; + PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* DescriptorProto_ExtensionRange::_internal_mutable_options() { + _has_bits_[0] |= 0x00000001u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* DescriptorProto_ExtensionRange::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.ExtensionRange.options) + return _internal_mutable_options(); +} +inline void DescriptorProto_ExtensionRange::set_allocated_options(PROTOBUF_NAMESPACE_ID::ExtensionRangeOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.DescriptorProto.ExtensionRange.options) +} + +// ------------------------------------------------------------------- + +// DescriptorProto_ReservedRange + +// optional int32 start = 1; +inline bool DescriptorProto_ReservedRange::_internal_has_start() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool DescriptorProto_ReservedRange::has_start() const { + return _internal_has_start(); +} +inline void DescriptorProto_ReservedRange::clear_start() { + start_ = 0; + _has_bits_[0] &= ~0x00000001u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ReservedRange::_internal_start() const { + return start_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ReservedRange::start() const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.ReservedRange.start) + return _internal_start(); +} +inline void DescriptorProto_ReservedRange::_internal_set_start(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000001u; + start_ = value; +} +inline void DescriptorProto_ReservedRange::set_start(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_start(value); + // @@protoc_insertion_point(field_set:google.protobuf.DescriptorProto.ReservedRange.start) +} + +// optional int32 end = 2; +inline bool DescriptorProto_ReservedRange::_internal_has_end() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool DescriptorProto_ReservedRange::has_end() const { + return _internal_has_end(); +} +inline void DescriptorProto_ReservedRange::clear_end() { + end_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ReservedRange::_internal_end() const { + return end_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 DescriptorProto_ReservedRange::end() const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.ReservedRange.end) + return _internal_end(); +} +inline void DescriptorProto_ReservedRange::_internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + end_ = value; +} +inline void DescriptorProto_ReservedRange::set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_end(value); + // @@protoc_insertion_point(field_set:google.protobuf.DescriptorProto.ReservedRange.end) +} + +// ------------------------------------------------------------------- + +// DescriptorProto + +// optional string name = 1; +inline bool DescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool DescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void DescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& DescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.name) + return _internal_name(); +} +inline void DescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.DescriptorProto.name) +} +inline std::string* DescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& DescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void DescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void DescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.DescriptorProto.name) +} +inline void DescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.DescriptorProto.name) +} +inline void DescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.DescriptorProto.name) +} +inline std::string* DescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* DescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.DescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void DescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.DescriptorProto.name) +} + +// repeated .google.protobuf.FieldDescriptorProto field = 2; +inline int DescriptorProto::_internal_field_size() const { + return field_.size(); +} +inline int DescriptorProto::field_size() const { + return _internal_field_size(); +} +inline void DescriptorProto::clear_field() { + field_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* DescriptorProto::mutable_field(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.field) + return field_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >* +DescriptorProto::mutable_field() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.field) + return &field_; +} +inline const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& DescriptorProto::_internal_field(int index) const { + return field_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& DescriptorProto::field(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.field) + return _internal_field(index); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* DescriptorProto::_internal_add_field() { + return field_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* DescriptorProto::add_field() { + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.field) + return _internal_add_field(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >& +DescriptorProto::field() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.field) + return field_; +} + +// repeated .google.protobuf.FieldDescriptorProto extension = 6; +inline int DescriptorProto::_internal_extension_size() const { + return extension_.size(); +} +inline int DescriptorProto::extension_size() const { + return _internal_extension_size(); +} +inline void DescriptorProto::clear_extension() { + extension_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* DescriptorProto::mutable_extension(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.extension) + return extension_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >* +DescriptorProto::mutable_extension() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.extension) + return &extension_; +} +inline const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& DescriptorProto::_internal_extension(int index) const { + return extension_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::FieldDescriptorProto& DescriptorProto::extension(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.extension) + return _internal_extension(index); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* DescriptorProto::_internal_add_extension() { + return extension_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto* DescriptorProto::add_extension() { + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.extension) + return _internal_add_extension(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto >& +DescriptorProto::extension() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.extension) + return extension_; +} + +// repeated .google.protobuf.DescriptorProto nested_type = 3; +inline int DescriptorProto::_internal_nested_type_size() const { + return nested_type_.size(); +} +inline int DescriptorProto::nested_type_size() const { + return _internal_nested_type_size(); +} +inline void DescriptorProto::clear_nested_type() { + nested_type_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto* DescriptorProto::mutable_nested_type(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.nested_type) + return nested_type_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >* +DescriptorProto::mutable_nested_type() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.nested_type) + return &nested_type_; +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto& DescriptorProto::_internal_nested_type(int index) const { + return nested_type_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto& DescriptorProto::nested_type(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.nested_type) + return _internal_nested_type(index); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto* DescriptorProto::_internal_add_nested_type() { + return nested_type_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto* DescriptorProto::add_nested_type() { + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.nested_type) + return _internal_add_nested_type(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto >& +DescriptorProto::nested_type() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.nested_type) + return nested_type_; +} + +// repeated .google.protobuf.EnumDescriptorProto enum_type = 4; +inline int DescriptorProto::_internal_enum_type_size() const { + return enum_type_.size(); +} +inline int DescriptorProto::enum_type_size() const { + return _internal_enum_type_size(); +} +inline void DescriptorProto::clear_enum_type() { + enum_type_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* DescriptorProto::mutable_enum_type(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.enum_type) + return enum_type_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >* +DescriptorProto::mutable_enum_type() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.enum_type) + return &enum_type_; +} +inline const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& DescriptorProto::_internal_enum_type(int index) const { + return enum_type_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto& DescriptorProto::enum_type(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.enum_type) + return _internal_enum_type(index); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* DescriptorProto::_internal_add_enum_type() { + return enum_type_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto* DescriptorProto::add_enum_type() { + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.enum_type) + return _internal_add_enum_type(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto >& +DescriptorProto::enum_type() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.enum_type) + return enum_type_; +} + +// repeated .google.protobuf.DescriptorProto.ExtensionRange extension_range = 5; +inline int DescriptorProto::_internal_extension_range_size() const { + return extension_range_.size(); +} +inline int DescriptorProto::extension_range_size() const { + return _internal_extension_range_size(); +} +inline void DescriptorProto::clear_extension_range() { + extension_range_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange* DescriptorProto::mutable_extension_range(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.extension_range) + return extension_range_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange >* +DescriptorProto::mutable_extension_range() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.extension_range) + return &extension_range_; +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange& DescriptorProto::_internal_extension_range(int index) const { + return extension_range_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange& DescriptorProto::extension_range(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.extension_range) + return _internal_extension_range(index); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange* DescriptorProto::_internal_add_extension_range() { + return extension_range_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange* DescriptorProto::add_extension_range() { + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.extension_range) + return _internal_add_extension_range(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ExtensionRange >& +DescriptorProto::extension_range() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.extension_range) + return extension_range_; +} + +// repeated .google.protobuf.OneofDescriptorProto oneof_decl = 8; +inline int DescriptorProto::_internal_oneof_decl_size() const { + return oneof_decl_.size(); +} +inline int DescriptorProto::oneof_decl_size() const { + return _internal_oneof_decl_size(); +} +inline void DescriptorProto::clear_oneof_decl() { + oneof_decl_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::OneofDescriptorProto* DescriptorProto::mutable_oneof_decl(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.oneof_decl) + return oneof_decl_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::OneofDescriptorProto >* +DescriptorProto::mutable_oneof_decl() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.oneof_decl) + return &oneof_decl_; +} +inline const PROTOBUF_NAMESPACE_ID::OneofDescriptorProto& DescriptorProto::_internal_oneof_decl(int index) const { + return oneof_decl_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::OneofDescriptorProto& DescriptorProto::oneof_decl(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.oneof_decl) + return _internal_oneof_decl(index); +} +inline PROTOBUF_NAMESPACE_ID::OneofDescriptorProto* DescriptorProto::_internal_add_oneof_decl() { + return oneof_decl_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::OneofDescriptorProto* DescriptorProto::add_oneof_decl() { + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.oneof_decl) + return _internal_add_oneof_decl(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::OneofDescriptorProto >& +DescriptorProto::oneof_decl() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.oneof_decl) + return oneof_decl_; +} + +// optional .google.protobuf.MessageOptions options = 7; +inline bool DescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool DescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void DescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const PROTOBUF_NAMESPACE_ID::MessageOptions& DescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::MessageOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_MessageOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::MessageOptions& DescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.options) + return _internal_options(); +} +inline void DescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::MessageOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.DescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::MessageOptions* DescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::MessageOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::MessageOptions* DescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.DescriptorProto.options) + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::MessageOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::MessageOptions* DescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000002u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::MessageOptions* DescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.options) + return _internal_mutable_options(); +} +inline void DescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::MessageOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.DescriptorProto.options) +} + +// repeated .google.protobuf.DescriptorProto.ReservedRange reserved_range = 9; +inline int DescriptorProto::_internal_reserved_range_size() const { + return reserved_range_.size(); +} +inline int DescriptorProto::reserved_range_size() const { + return _internal_reserved_range_size(); +} +inline void DescriptorProto::clear_reserved_range() { + reserved_range_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange* DescriptorProto::mutable_reserved_range(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.reserved_range) + return reserved_range_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange >* +DescriptorProto::mutable_reserved_range() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.reserved_range) + return &reserved_range_; +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange& DescriptorProto::_internal_reserved_range(int index) const { + return reserved_range_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange& DescriptorProto::reserved_range(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.reserved_range) + return _internal_reserved_range(index); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange* DescriptorProto::_internal_add_reserved_range() { + return reserved_range_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange* DescriptorProto::add_reserved_range() { + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.reserved_range) + return _internal_add_reserved_range(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::DescriptorProto_ReservedRange >& +DescriptorProto::reserved_range() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.reserved_range) + return reserved_range_; +} + +// repeated string reserved_name = 10; +inline int DescriptorProto::_internal_reserved_name_size() const { + return reserved_name_.size(); +} +inline int DescriptorProto::reserved_name_size() const { + return _internal_reserved_name_size(); +} +inline void DescriptorProto::clear_reserved_name() { + reserved_name_.Clear(); +} +inline std::string* DescriptorProto::add_reserved_name() { + // @@protoc_insertion_point(field_add_mutable:google.protobuf.DescriptorProto.reserved_name) + return _internal_add_reserved_name(); +} +inline const std::string& DescriptorProto::_internal_reserved_name(int index) const { + return reserved_name_.Get(index); +} +inline const std::string& DescriptorProto::reserved_name(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.DescriptorProto.reserved_name) + return _internal_reserved_name(index); +} +inline std::string* DescriptorProto::mutable_reserved_name(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.DescriptorProto.reserved_name) + return reserved_name_.Mutable(index); +} +inline void DescriptorProto::set_reserved_name(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:google.protobuf.DescriptorProto.reserved_name) + reserved_name_.Mutable(index)->assign(value); +} +inline void DescriptorProto::set_reserved_name(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.DescriptorProto.reserved_name) + reserved_name_.Mutable(index)->assign(std::move(value)); +} +inline void DescriptorProto::set_reserved_name(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + reserved_name_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:google.protobuf.DescriptorProto.reserved_name) +} +inline void DescriptorProto::set_reserved_name(int index, const char* value, size_t size) { + reserved_name_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.DescriptorProto.reserved_name) +} +inline std::string* DescriptorProto::_internal_add_reserved_name() { + return reserved_name_.Add(); +} +inline void DescriptorProto::add_reserved_name(const std::string& value) { + reserved_name_.Add()->assign(value); + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.reserved_name) +} +inline void DescriptorProto::add_reserved_name(std::string&& value) { + reserved_name_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:google.protobuf.DescriptorProto.reserved_name) +} +inline void DescriptorProto::add_reserved_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + reserved_name_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:google.protobuf.DescriptorProto.reserved_name) +} +inline void DescriptorProto::add_reserved_name(const char* value, size_t size) { + reserved_name_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:google.protobuf.DescriptorProto.reserved_name) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +DescriptorProto::reserved_name() const { + // @@protoc_insertion_point(field_list:google.protobuf.DescriptorProto.reserved_name) + return reserved_name_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +DescriptorProto::mutable_reserved_name() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.DescriptorProto.reserved_name) + return &reserved_name_; +} + +// ------------------------------------------------------------------- + +// ExtensionRangeOptions + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int ExtensionRangeOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int ExtensionRangeOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void ExtensionRangeOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* ExtensionRangeOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.ExtensionRangeOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +ExtensionRangeOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.ExtensionRangeOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& ExtensionRangeOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& ExtensionRangeOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.ExtensionRangeOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* ExtensionRangeOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* ExtensionRangeOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.ExtensionRangeOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +ExtensionRangeOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.ExtensionRangeOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// FieldDescriptorProto + +// optional string name = 1; +inline bool FieldDescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void FieldDescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& FieldDescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.name) + return _internal_name(); +} +inline void FieldDescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.name) +} +inline std::string* FieldDescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldDescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& FieldDescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void FieldDescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FieldDescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FieldDescriptorProto.name) +} +inline void FieldDescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FieldDescriptorProto.name) +} +inline void FieldDescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FieldDescriptorProto.name) +} +inline std::string* FieldDescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FieldDescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.FieldDescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FieldDescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FieldDescriptorProto.name) +} + +// optional int32 number = 3; +inline bool FieldDescriptorProto::_internal_has_number() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_number() const { + return _internal_has_number(); +} +inline void FieldDescriptorProto::clear_number() { + number_ = 0; + _has_bits_[0] &= ~0x00000040u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FieldDescriptorProto::_internal_number() const { + return number_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FieldDescriptorProto::number() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.number) + return _internal_number(); +} +inline void FieldDescriptorProto::_internal_set_number(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000040u; + number_ = value; +} +inline void FieldDescriptorProto::set_number(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_number(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.number) +} + +// optional .google.protobuf.FieldDescriptorProto.Label label = 4; +inline bool FieldDescriptorProto::_internal_has_label() const { + bool value = (_has_bits_[0] & 0x00000200u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_label() const { + return _internal_has_label(); +} +inline void FieldDescriptorProto::clear_label() { + label_ = 1; + _has_bits_[0] &= ~0x00000200u; +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label FieldDescriptorProto::_internal_label() const { + return static_cast< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label >(label_); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label FieldDescriptorProto::label() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.label) + return _internal_label(); +} +inline void FieldDescriptorProto::_internal_set_label(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label value) { + assert(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label_IsValid(value)); + _has_bits_[0] |= 0x00000200u; + label_ = value; +} +inline void FieldDescriptorProto::set_label(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label value) { + _internal_set_label(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.label) +} + +// optional .google.protobuf.FieldDescriptorProto.Type type = 5; +inline bool FieldDescriptorProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000400u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_type() const { + return _internal_has_type(); +} +inline void FieldDescriptorProto::clear_type() { + type_ = 1; + _has_bits_[0] &= ~0x00000400u; +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type FieldDescriptorProto::_internal_type() const { + return static_cast< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type >(type_); +} +inline PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type FieldDescriptorProto::type() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.type) + return _internal_type(); +} +inline void FieldDescriptorProto::_internal_set_type(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type value) { + assert(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type_IsValid(value)); + _has_bits_[0] |= 0x00000400u; + type_ = value; +} +inline void FieldDescriptorProto::set_type(PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type value) { + _internal_set_type(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.type) +} + +// optional string type_name = 6; +inline bool FieldDescriptorProto::_internal_has_type_name() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_type_name() const { + return _internal_has_type_name(); +} +inline void FieldDescriptorProto::clear_type_name() { + type_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& FieldDescriptorProto::type_name() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.type_name) + return _internal_type_name(); +} +inline void FieldDescriptorProto::set_type_name(const std::string& value) { + _internal_set_type_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.type_name) +} +inline std::string* FieldDescriptorProto::mutable_type_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldDescriptorProto.type_name) + return _internal_mutable_type_name(); +} +inline const std::string& FieldDescriptorProto::_internal_type_name() const { + return type_name_.Get(); +} +inline void FieldDescriptorProto::_internal_set_type_name(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FieldDescriptorProto::set_type_name(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + type_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FieldDescriptorProto.type_name) +} +inline void FieldDescriptorProto::set_type_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FieldDescriptorProto.type_name) +} +inline void FieldDescriptorProto::set_type_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FieldDescriptorProto.type_name) +} +inline std::string* FieldDescriptorProto::_internal_mutable_type_name() { + _has_bits_[0] |= 0x00000004u; + return type_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FieldDescriptorProto::release_type_name() { + // @@protoc_insertion_point(field_release:google.protobuf.FieldDescriptorProto.type_name) + if (!_internal_has_type_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return type_name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FieldDescriptorProto::set_allocated_type_name(std::string* type_name) { + if (type_name != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + type_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), type_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FieldDescriptorProto.type_name) +} + +// optional string extendee = 2; +inline bool FieldDescriptorProto::_internal_has_extendee() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_extendee() const { + return _internal_has_extendee(); +} +inline void FieldDescriptorProto::clear_extendee() { + extendee_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& FieldDescriptorProto::extendee() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.extendee) + return _internal_extendee(); +} +inline void FieldDescriptorProto::set_extendee(const std::string& value) { + _internal_set_extendee(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.extendee) +} +inline std::string* FieldDescriptorProto::mutable_extendee() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldDescriptorProto.extendee) + return _internal_mutable_extendee(); +} +inline const std::string& FieldDescriptorProto::_internal_extendee() const { + return extendee_.Get(); +} +inline void FieldDescriptorProto::_internal_set_extendee(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + extendee_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FieldDescriptorProto::set_extendee(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + extendee_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FieldDescriptorProto.extendee) +} +inline void FieldDescriptorProto::set_extendee(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + extendee_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FieldDescriptorProto.extendee) +} +inline void FieldDescriptorProto::set_extendee(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + extendee_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FieldDescriptorProto.extendee) +} +inline std::string* FieldDescriptorProto::_internal_mutable_extendee() { + _has_bits_[0] |= 0x00000002u; + return extendee_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FieldDescriptorProto::release_extendee() { + // @@protoc_insertion_point(field_release:google.protobuf.FieldDescriptorProto.extendee) + if (!_internal_has_extendee()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return extendee_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FieldDescriptorProto::set_allocated_extendee(std::string* extendee) { + if (extendee != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + extendee_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), extendee, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FieldDescriptorProto.extendee) +} + +// optional string default_value = 7; +inline bool FieldDescriptorProto::_internal_has_default_value() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_default_value() const { + return _internal_has_default_value(); +} +inline void FieldDescriptorProto::clear_default_value() { + default_value_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& FieldDescriptorProto::default_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.default_value) + return _internal_default_value(); +} +inline void FieldDescriptorProto::set_default_value(const std::string& value) { + _internal_set_default_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.default_value) +} +inline std::string* FieldDescriptorProto::mutable_default_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldDescriptorProto.default_value) + return _internal_mutable_default_value(); +} +inline const std::string& FieldDescriptorProto::_internal_default_value() const { + return default_value_.Get(); +} +inline void FieldDescriptorProto::_internal_set_default_value(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + default_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FieldDescriptorProto::set_default_value(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + default_value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FieldDescriptorProto.default_value) +} +inline void FieldDescriptorProto::set_default_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + default_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FieldDescriptorProto.default_value) +} +inline void FieldDescriptorProto::set_default_value(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000008u; + default_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FieldDescriptorProto.default_value) +} +inline std::string* FieldDescriptorProto::_internal_mutable_default_value() { + _has_bits_[0] |= 0x00000008u; + return default_value_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FieldDescriptorProto::release_default_value() { + // @@protoc_insertion_point(field_release:google.protobuf.FieldDescriptorProto.default_value) + if (!_internal_has_default_value()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return default_value_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FieldDescriptorProto::set_allocated_default_value(std::string* default_value) { + if (default_value != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + default_value_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), default_value, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FieldDescriptorProto.default_value) +} + +// optional int32 oneof_index = 9; +inline bool FieldDescriptorProto::_internal_has_oneof_index() const { + bool value = (_has_bits_[0] & 0x00000080u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_oneof_index() const { + return _internal_has_oneof_index(); +} +inline void FieldDescriptorProto::clear_oneof_index() { + oneof_index_ = 0; + _has_bits_[0] &= ~0x00000080u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FieldDescriptorProto::_internal_oneof_index() const { + return oneof_index_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 FieldDescriptorProto::oneof_index() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.oneof_index) + return _internal_oneof_index(); +} +inline void FieldDescriptorProto::_internal_set_oneof_index(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000080u; + oneof_index_ = value; +} +inline void FieldDescriptorProto::set_oneof_index(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_oneof_index(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.oneof_index) +} + +// optional string json_name = 10; +inline bool FieldDescriptorProto::_internal_has_json_name() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_json_name() const { + return _internal_has_json_name(); +} +inline void FieldDescriptorProto::clear_json_name() { + json_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000010u; +} +inline const std::string& FieldDescriptorProto::json_name() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.json_name) + return _internal_json_name(); +} +inline void FieldDescriptorProto::set_json_name(const std::string& value) { + _internal_set_json_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.json_name) +} +inline std::string* FieldDescriptorProto::mutable_json_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldDescriptorProto.json_name) + return _internal_mutable_json_name(); +} +inline const std::string& FieldDescriptorProto::_internal_json_name() const { + return json_name_.Get(); +} +inline void FieldDescriptorProto::_internal_set_json_name(const std::string& value) { + _has_bits_[0] |= 0x00000010u; + json_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FieldDescriptorProto::set_json_name(std::string&& value) { + _has_bits_[0] |= 0x00000010u; + json_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FieldDescriptorProto.json_name) +} +inline void FieldDescriptorProto::set_json_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000010u; + json_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FieldDescriptorProto.json_name) +} +inline void FieldDescriptorProto::set_json_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000010u; + json_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FieldDescriptorProto.json_name) +} +inline std::string* FieldDescriptorProto::_internal_mutable_json_name() { + _has_bits_[0] |= 0x00000010u; + return json_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FieldDescriptorProto::release_json_name() { + // @@protoc_insertion_point(field_release:google.protobuf.FieldDescriptorProto.json_name) + if (!_internal_has_json_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000010u; + return json_name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FieldDescriptorProto::set_allocated_json_name(std::string* json_name) { + if (json_name != nullptr) { + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + json_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), json_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FieldDescriptorProto.json_name) +} + +// optional .google.protobuf.FieldOptions options = 8; +inline bool FieldDescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool FieldDescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void FieldDescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000020u; +} +inline const PROTOBUF_NAMESPACE_ID::FieldOptions& FieldDescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::FieldOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_FieldOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::FieldOptions& FieldDescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.options) + return _internal_options(); +} +inline void FieldDescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::FieldOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.FieldDescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions* FieldDescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000020u; + PROTOBUF_NAMESPACE_ID::FieldOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions* FieldDescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.FieldDescriptorProto.options) + _has_bits_[0] &= ~0x00000020u; + PROTOBUF_NAMESPACE_ID::FieldOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions* FieldDescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000020u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions* FieldDescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldDescriptorProto.options) + return _internal_mutable_options(); +} +inline void FieldDescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::FieldOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FieldDescriptorProto.options) +} + +// optional bool proto3_optional = 17; +inline bool FieldDescriptorProto::_internal_has_proto3_optional() const { + bool value = (_has_bits_[0] & 0x00000100u) != 0; + return value; +} +inline bool FieldDescriptorProto::has_proto3_optional() const { + return _internal_has_proto3_optional(); +} +inline void FieldDescriptorProto::clear_proto3_optional() { + proto3_optional_ = false; + _has_bits_[0] &= ~0x00000100u; +} +inline bool FieldDescriptorProto::_internal_proto3_optional() const { + return proto3_optional_; +} +inline bool FieldDescriptorProto::proto3_optional() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldDescriptorProto.proto3_optional) + return _internal_proto3_optional(); +} +inline void FieldDescriptorProto::_internal_set_proto3_optional(bool value) { + _has_bits_[0] |= 0x00000100u; + proto3_optional_ = value; +} +inline void FieldDescriptorProto::set_proto3_optional(bool value) { + _internal_set_proto3_optional(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldDescriptorProto.proto3_optional) +} + +// ------------------------------------------------------------------- + +// OneofDescriptorProto + +// optional string name = 1; +inline bool OneofDescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool OneofDescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void OneofDescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& OneofDescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.OneofDescriptorProto.name) + return _internal_name(); +} +inline void OneofDescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.OneofDescriptorProto.name) +} +inline std::string* OneofDescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.OneofDescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& OneofDescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void OneofDescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void OneofDescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.OneofDescriptorProto.name) +} +inline void OneofDescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.OneofDescriptorProto.name) +} +inline void OneofDescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.OneofDescriptorProto.name) +} +inline std::string* OneofDescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* OneofDescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.OneofDescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void OneofDescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.OneofDescriptorProto.name) +} + +// optional .google.protobuf.OneofOptions options = 2; +inline bool OneofDescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool OneofDescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void OneofDescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const PROTOBUF_NAMESPACE_ID::OneofOptions& OneofDescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::OneofOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_OneofOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::OneofOptions& OneofDescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.OneofDescriptorProto.options) + return _internal_options(); +} +inline void OneofDescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::OneofOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.OneofDescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::OneofOptions* OneofDescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::OneofOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::OneofOptions* OneofDescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.OneofDescriptorProto.options) + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::OneofOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::OneofOptions* OneofDescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000002u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::OneofOptions* OneofDescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.OneofDescriptorProto.options) + return _internal_mutable_options(); +} +inline void OneofDescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::OneofOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.OneofDescriptorProto.options) +} + +// ------------------------------------------------------------------- + +// EnumDescriptorProto_EnumReservedRange + +// optional int32 start = 1; +inline bool EnumDescriptorProto_EnumReservedRange::_internal_has_start() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool EnumDescriptorProto_EnumReservedRange::has_start() const { + return _internal_has_start(); +} +inline void EnumDescriptorProto_EnumReservedRange::clear_start() { + start_ = 0; + _has_bits_[0] &= ~0x00000001u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 EnumDescriptorProto_EnumReservedRange::_internal_start() const { + return start_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 EnumDescriptorProto_EnumReservedRange::start() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumDescriptorProto.EnumReservedRange.start) + return _internal_start(); +} +inline void EnumDescriptorProto_EnumReservedRange::_internal_set_start(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000001u; + start_ = value; +} +inline void EnumDescriptorProto_EnumReservedRange::set_start(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_start(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumDescriptorProto.EnumReservedRange.start) +} + +// optional int32 end = 2; +inline bool EnumDescriptorProto_EnumReservedRange::_internal_has_end() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool EnumDescriptorProto_EnumReservedRange::has_end() const { + return _internal_has_end(); +} +inline void EnumDescriptorProto_EnumReservedRange::clear_end() { + end_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 EnumDescriptorProto_EnumReservedRange::_internal_end() const { + return end_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 EnumDescriptorProto_EnumReservedRange::end() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumDescriptorProto.EnumReservedRange.end) + return _internal_end(); +} +inline void EnumDescriptorProto_EnumReservedRange::_internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + end_ = value; +} +inline void EnumDescriptorProto_EnumReservedRange::set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_end(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumDescriptorProto.EnumReservedRange.end) +} + +// ------------------------------------------------------------------- + +// EnumDescriptorProto + +// optional string name = 1; +inline bool EnumDescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool EnumDescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void EnumDescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& EnumDescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumDescriptorProto.name) + return _internal_name(); +} +inline void EnumDescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumDescriptorProto.name) +} +inline std::string* EnumDescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumDescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& EnumDescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void EnumDescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void EnumDescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.EnumDescriptorProto.name) +} +inline void EnumDescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.EnumDescriptorProto.name) +} +inline void EnumDescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.EnumDescriptorProto.name) +} +inline std::string* EnumDescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* EnumDescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.EnumDescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void EnumDescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.EnumDescriptorProto.name) +} + +// repeated .google.protobuf.EnumValueDescriptorProto value = 2; +inline int EnumDescriptorProto::_internal_value_size() const { + return value_.size(); +} +inline int EnumDescriptorProto::value_size() const { + return _internal_value_size(); +} +inline void EnumDescriptorProto::clear_value() { + value_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto* EnumDescriptorProto::mutable_value(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumDescriptorProto.value) + return value_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto >* +EnumDescriptorProto::mutable_value() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.EnumDescriptorProto.value) + return &value_; +} +inline const PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto& EnumDescriptorProto::_internal_value(int index) const { + return value_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto& EnumDescriptorProto::value(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumDescriptorProto.value) + return _internal_value(index); +} +inline PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto* EnumDescriptorProto::_internal_add_value() { + return value_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto* EnumDescriptorProto::add_value() { + // @@protoc_insertion_point(field_add:google.protobuf.EnumDescriptorProto.value) + return _internal_add_value(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValueDescriptorProto >& +EnumDescriptorProto::value() const { + // @@protoc_insertion_point(field_list:google.protobuf.EnumDescriptorProto.value) + return value_; +} + +// optional .google.protobuf.EnumOptions options = 3; +inline bool EnumDescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool EnumDescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void EnumDescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const PROTOBUF_NAMESPACE_ID::EnumOptions& EnumDescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::EnumOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_EnumOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::EnumOptions& EnumDescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumDescriptorProto.options) + return _internal_options(); +} +inline void EnumDescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::EnumOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.EnumDescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::EnumOptions* EnumDescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::EnumOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::EnumOptions* EnumDescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.EnumDescriptorProto.options) + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::EnumOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::EnumOptions* EnumDescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000002u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::EnumOptions* EnumDescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumDescriptorProto.options) + return _internal_mutable_options(); +} +inline void EnumDescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::EnumOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.EnumDescriptorProto.options) +} + +// repeated .google.protobuf.EnumDescriptorProto.EnumReservedRange reserved_range = 4; +inline int EnumDescriptorProto::_internal_reserved_range_size() const { + return reserved_range_.size(); +} +inline int EnumDescriptorProto::reserved_range_size() const { + return _internal_reserved_range_size(); +} +inline void EnumDescriptorProto::clear_reserved_range() { + reserved_range_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange* EnumDescriptorProto::mutable_reserved_range(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumDescriptorProto.reserved_range) + return reserved_range_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange >* +EnumDescriptorProto::mutable_reserved_range() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.EnumDescriptorProto.reserved_range) + return &reserved_range_; +} +inline const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange& EnumDescriptorProto::_internal_reserved_range(int index) const { + return reserved_range_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange& EnumDescriptorProto::reserved_range(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumDescriptorProto.reserved_range) + return _internal_reserved_range(index); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange* EnumDescriptorProto::_internal_add_reserved_range() { + return reserved_range_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange* EnumDescriptorProto::add_reserved_range() { + // @@protoc_insertion_point(field_add:google.protobuf.EnumDescriptorProto.reserved_range) + return _internal_add_reserved_range(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumDescriptorProto_EnumReservedRange >& +EnumDescriptorProto::reserved_range() const { + // @@protoc_insertion_point(field_list:google.protobuf.EnumDescriptorProto.reserved_range) + return reserved_range_; +} + +// repeated string reserved_name = 5; +inline int EnumDescriptorProto::_internal_reserved_name_size() const { + return reserved_name_.size(); +} +inline int EnumDescriptorProto::reserved_name_size() const { + return _internal_reserved_name_size(); +} +inline void EnumDescriptorProto::clear_reserved_name() { + reserved_name_.Clear(); +} +inline std::string* EnumDescriptorProto::add_reserved_name() { + // @@protoc_insertion_point(field_add_mutable:google.protobuf.EnumDescriptorProto.reserved_name) + return _internal_add_reserved_name(); +} +inline const std::string& EnumDescriptorProto::_internal_reserved_name(int index) const { + return reserved_name_.Get(index); +} +inline const std::string& EnumDescriptorProto::reserved_name(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumDescriptorProto.reserved_name) + return _internal_reserved_name(index); +} +inline std::string* EnumDescriptorProto::mutable_reserved_name(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumDescriptorProto.reserved_name) + return reserved_name_.Mutable(index); +} +inline void EnumDescriptorProto::set_reserved_name(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:google.protobuf.EnumDescriptorProto.reserved_name) + reserved_name_.Mutable(index)->assign(value); +} +inline void EnumDescriptorProto::set_reserved_name(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.EnumDescriptorProto.reserved_name) + reserved_name_.Mutable(index)->assign(std::move(value)); +} +inline void EnumDescriptorProto::set_reserved_name(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + reserved_name_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:google.protobuf.EnumDescriptorProto.reserved_name) +} +inline void EnumDescriptorProto::set_reserved_name(int index, const char* value, size_t size) { + reserved_name_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.EnumDescriptorProto.reserved_name) +} +inline std::string* EnumDescriptorProto::_internal_add_reserved_name() { + return reserved_name_.Add(); +} +inline void EnumDescriptorProto::add_reserved_name(const std::string& value) { + reserved_name_.Add()->assign(value); + // @@protoc_insertion_point(field_add:google.protobuf.EnumDescriptorProto.reserved_name) +} +inline void EnumDescriptorProto::add_reserved_name(std::string&& value) { + reserved_name_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:google.protobuf.EnumDescriptorProto.reserved_name) +} +inline void EnumDescriptorProto::add_reserved_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + reserved_name_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:google.protobuf.EnumDescriptorProto.reserved_name) +} +inline void EnumDescriptorProto::add_reserved_name(const char* value, size_t size) { + reserved_name_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:google.protobuf.EnumDescriptorProto.reserved_name) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +EnumDescriptorProto::reserved_name() const { + // @@protoc_insertion_point(field_list:google.protobuf.EnumDescriptorProto.reserved_name) + return reserved_name_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +EnumDescriptorProto::mutable_reserved_name() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.EnumDescriptorProto.reserved_name) + return &reserved_name_; +} + +// ------------------------------------------------------------------- + +// EnumValueDescriptorProto + +// optional string name = 1; +inline bool EnumValueDescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool EnumValueDescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void EnumValueDescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& EnumValueDescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumValueDescriptorProto.name) + return _internal_name(); +} +inline void EnumValueDescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumValueDescriptorProto.name) +} +inline std::string* EnumValueDescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumValueDescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& EnumValueDescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void EnumValueDescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void EnumValueDescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.EnumValueDescriptorProto.name) +} +inline void EnumValueDescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.EnumValueDescriptorProto.name) +} +inline void EnumValueDescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.EnumValueDescriptorProto.name) +} +inline std::string* EnumValueDescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* EnumValueDescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.EnumValueDescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void EnumValueDescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.EnumValueDescriptorProto.name) +} + +// optional int32 number = 2; +inline bool EnumValueDescriptorProto::_internal_has_number() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool EnumValueDescriptorProto::has_number() const { + return _internal_has_number(); +} +inline void EnumValueDescriptorProto::clear_number() { + number_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 EnumValueDescriptorProto::_internal_number() const { + return number_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 EnumValueDescriptorProto::number() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumValueDescriptorProto.number) + return _internal_number(); +} +inline void EnumValueDescriptorProto::_internal_set_number(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + number_ = value; +} +inline void EnumValueDescriptorProto::set_number(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_number(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumValueDescriptorProto.number) +} + +// optional .google.protobuf.EnumValueOptions options = 3; +inline bool EnumValueDescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool EnumValueDescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void EnumValueDescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const PROTOBUF_NAMESPACE_ID::EnumValueOptions& EnumValueDescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::EnumValueOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_EnumValueOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::EnumValueOptions& EnumValueDescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumValueDescriptorProto.options) + return _internal_options(); +} +inline void EnumValueDescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::EnumValueOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.EnumValueDescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::EnumValueOptions* EnumValueDescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::EnumValueOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::EnumValueOptions* EnumValueDescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.EnumValueDescriptorProto.options) + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::EnumValueOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::EnumValueOptions* EnumValueDescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000002u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::EnumValueOptions* EnumValueDescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumValueDescriptorProto.options) + return _internal_mutable_options(); +} +inline void EnumValueDescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::EnumValueOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.EnumValueDescriptorProto.options) +} + +// ------------------------------------------------------------------- + +// ServiceDescriptorProto + +// optional string name = 1; +inline bool ServiceDescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ServiceDescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void ServiceDescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ServiceDescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.ServiceDescriptorProto.name) + return _internal_name(); +} +inline void ServiceDescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.ServiceDescriptorProto.name) +} +inline std::string* ServiceDescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.ServiceDescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& ServiceDescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void ServiceDescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ServiceDescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.ServiceDescriptorProto.name) +} +inline void ServiceDescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.ServiceDescriptorProto.name) +} +inline void ServiceDescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.ServiceDescriptorProto.name) +} +inline std::string* ServiceDescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ServiceDescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.ServiceDescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ServiceDescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.ServiceDescriptorProto.name) +} + +// repeated .google.protobuf.MethodDescriptorProto method = 2; +inline int ServiceDescriptorProto::_internal_method_size() const { + return method_.size(); +} +inline int ServiceDescriptorProto::method_size() const { + return _internal_method_size(); +} +inline void ServiceDescriptorProto::clear_method() { + method_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::MethodDescriptorProto* ServiceDescriptorProto::mutable_method(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.ServiceDescriptorProto.method) + return method_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::MethodDescriptorProto >* +ServiceDescriptorProto::mutable_method() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.ServiceDescriptorProto.method) + return &method_; +} +inline const PROTOBUF_NAMESPACE_ID::MethodDescriptorProto& ServiceDescriptorProto::_internal_method(int index) const { + return method_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::MethodDescriptorProto& ServiceDescriptorProto::method(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.ServiceDescriptorProto.method) + return _internal_method(index); +} +inline PROTOBUF_NAMESPACE_ID::MethodDescriptorProto* ServiceDescriptorProto::_internal_add_method() { + return method_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::MethodDescriptorProto* ServiceDescriptorProto::add_method() { + // @@protoc_insertion_point(field_add:google.protobuf.ServiceDescriptorProto.method) + return _internal_add_method(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::MethodDescriptorProto >& +ServiceDescriptorProto::method() const { + // @@protoc_insertion_point(field_list:google.protobuf.ServiceDescriptorProto.method) + return method_; +} + +// optional .google.protobuf.ServiceOptions options = 3; +inline bool ServiceDescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool ServiceDescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void ServiceDescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const PROTOBUF_NAMESPACE_ID::ServiceOptions& ServiceDescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::ServiceOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_ServiceOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::ServiceOptions& ServiceDescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.ServiceDescriptorProto.options) + return _internal_options(); +} +inline void ServiceDescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::ServiceOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.ServiceDescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::ServiceOptions* ServiceDescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::ServiceOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::ServiceOptions* ServiceDescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.ServiceDescriptorProto.options) + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::ServiceOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::ServiceOptions* ServiceDescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000002u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::ServiceOptions* ServiceDescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.ServiceDescriptorProto.options) + return _internal_mutable_options(); +} +inline void ServiceDescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::ServiceOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.ServiceDescriptorProto.options) +} + +// ------------------------------------------------------------------- + +// MethodDescriptorProto + +// optional string name = 1; +inline bool MethodDescriptorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool MethodDescriptorProto::has_name() const { + return _internal_has_name(); +} +inline void MethodDescriptorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& MethodDescriptorProto::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.name) + return _internal_name(); +} +inline void MethodDescriptorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.MethodDescriptorProto.name) +} +inline std::string* MethodDescriptorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.MethodDescriptorProto.name) + return _internal_mutable_name(); +} +inline const std::string& MethodDescriptorProto::_internal_name() const { + return name_.Get(); +} +inline void MethodDescriptorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void MethodDescriptorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.MethodDescriptorProto.name) +} +inline void MethodDescriptorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.MethodDescriptorProto.name) +} +inline void MethodDescriptorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.MethodDescriptorProto.name) +} +inline std::string* MethodDescriptorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* MethodDescriptorProto::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.MethodDescriptorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void MethodDescriptorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.MethodDescriptorProto.name) +} + +// optional string input_type = 2; +inline bool MethodDescriptorProto::_internal_has_input_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool MethodDescriptorProto::has_input_type() const { + return _internal_has_input_type(); +} +inline void MethodDescriptorProto::clear_input_type() { + input_type_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& MethodDescriptorProto::input_type() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.input_type) + return _internal_input_type(); +} +inline void MethodDescriptorProto::set_input_type(const std::string& value) { + _internal_set_input_type(value); + // @@protoc_insertion_point(field_set:google.protobuf.MethodDescriptorProto.input_type) +} +inline std::string* MethodDescriptorProto::mutable_input_type() { + // @@protoc_insertion_point(field_mutable:google.protobuf.MethodDescriptorProto.input_type) + return _internal_mutable_input_type(); +} +inline const std::string& MethodDescriptorProto::_internal_input_type() const { + return input_type_.Get(); +} +inline void MethodDescriptorProto::_internal_set_input_type(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + input_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void MethodDescriptorProto::set_input_type(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + input_type_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.MethodDescriptorProto.input_type) +} +inline void MethodDescriptorProto::set_input_type(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + input_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.MethodDescriptorProto.input_type) +} +inline void MethodDescriptorProto::set_input_type(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + input_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.MethodDescriptorProto.input_type) +} +inline std::string* MethodDescriptorProto::_internal_mutable_input_type() { + _has_bits_[0] |= 0x00000002u; + return input_type_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* MethodDescriptorProto::release_input_type() { + // @@protoc_insertion_point(field_release:google.protobuf.MethodDescriptorProto.input_type) + if (!_internal_has_input_type()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return input_type_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void MethodDescriptorProto::set_allocated_input_type(std::string* input_type) { + if (input_type != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + input_type_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), input_type, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.MethodDescriptorProto.input_type) +} + +// optional string output_type = 3; +inline bool MethodDescriptorProto::_internal_has_output_type() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool MethodDescriptorProto::has_output_type() const { + return _internal_has_output_type(); +} +inline void MethodDescriptorProto::clear_output_type() { + output_type_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& MethodDescriptorProto::output_type() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.output_type) + return _internal_output_type(); +} +inline void MethodDescriptorProto::set_output_type(const std::string& value) { + _internal_set_output_type(value); + // @@protoc_insertion_point(field_set:google.protobuf.MethodDescriptorProto.output_type) +} +inline std::string* MethodDescriptorProto::mutable_output_type() { + // @@protoc_insertion_point(field_mutable:google.protobuf.MethodDescriptorProto.output_type) + return _internal_mutable_output_type(); +} +inline const std::string& MethodDescriptorProto::_internal_output_type() const { + return output_type_.Get(); +} +inline void MethodDescriptorProto::_internal_set_output_type(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + output_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void MethodDescriptorProto::set_output_type(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + output_type_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.MethodDescriptorProto.output_type) +} +inline void MethodDescriptorProto::set_output_type(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + output_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.MethodDescriptorProto.output_type) +} +inline void MethodDescriptorProto::set_output_type(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + output_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.MethodDescriptorProto.output_type) +} +inline std::string* MethodDescriptorProto::_internal_mutable_output_type() { + _has_bits_[0] |= 0x00000004u; + return output_type_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* MethodDescriptorProto::release_output_type() { + // @@protoc_insertion_point(field_release:google.protobuf.MethodDescriptorProto.output_type) + if (!_internal_has_output_type()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return output_type_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void MethodDescriptorProto::set_allocated_output_type(std::string* output_type) { + if (output_type != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + output_type_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), output_type, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.MethodDescriptorProto.output_type) +} + +// optional .google.protobuf.MethodOptions options = 4; +inline bool MethodDescriptorProto::_internal_has_options() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + PROTOBUF_ASSUME(!value || options_ != nullptr); + return value; +} +inline bool MethodDescriptorProto::has_options() const { + return _internal_has_options(); +} +inline void MethodDescriptorProto::clear_options() { + if (options_ != nullptr) options_->Clear(); + _has_bits_[0] &= ~0x00000008u; +} +inline const PROTOBUF_NAMESPACE_ID::MethodOptions& MethodDescriptorProto::_internal_options() const { + const PROTOBUF_NAMESPACE_ID::MethodOptions* p = options_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::_MethodOptions_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::MethodOptions& MethodDescriptorProto::options() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.options) + return _internal_options(); +} +inline void MethodDescriptorProto::unsafe_arena_set_allocated_options( + PROTOBUF_NAMESPACE_ID::MethodOptions* options) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(options_); + } + options_ = options; + if (options) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.MethodDescriptorProto.options) +} +inline PROTOBUF_NAMESPACE_ID::MethodOptions* MethodDescriptorProto::release_options() { + _has_bits_[0] &= ~0x00000008u; + PROTOBUF_NAMESPACE_ID::MethodOptions* temp = options_; + options_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::MethodOptions* MethodDescriptorProto::unsafe_arena_release_options() { + // @@protoc_insertion_point(field_release:google.protobuf.MethodDescriptorProto.options) + _has_bits_[0] &= ~0x00000008u; + PROTOBUF_NAMESPACE_ID::MethodOptions* temp = options_; + options_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::MethodOptions* MethodDescriptorProto::_internal_mutable_options() { + _has_bits_[0] |= 0x00000008u; + if (options_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + options_ = p; + } + return options_; +} +inline PROTOBUF_NAMESPACE_ID::MethodOptions* MethodDescriptorProto::mutable_options() { + // @@protoc_insertion_point(field_mutable:google.protobuf.MethodDescriptorProto.options) + return _internal_mutable_options(); +} +inline void MethodDescriptorProto::set_allocated_options(PROTOBUF_NAMESPACE_ID::MethodOptions* options) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete options_; + } + if (options) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(options); + if (message_arena != submessage_arena) { + options = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, options, submessage_arena); + } + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + options_ = options; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.MethodDescriptorProto.options) +} + +// optional bool client_streaming = 5 [default = false]; +inline bool MethodDescriptorProto::_internal_has_client_streaming() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool MethodDescriptorProto::has_client_streaming() const { + return _internal_has_client_streaming(); +} +inline void MethodDescriptorProto::clear_client_streaming() { + client_streaming_ = false; + _has_bits_[0] &= ~0x00000010u; +} +inline bool MethodDescriptorProto::_internal_client_streaming() const { + return client_streaming_; +} +inline bool MethodDescriptorProto::client_streaming() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.client_streaming) + return _internal_client_streaming(); +} +inline void MethodDescriptorProto::_internal_set_client_streaming(bool value) { + _has_bits_[0] |= 0x00000010u; + client_streaming_ = value; +} +inline void MethodDescriptorProto::set_client_streaming(bool value) { + _internal_set_client_streaming(value); + // @@protoc_insertion_point(field_set:google.protobuf.MethodDescriptorProto.client_streaming) +} + +// optional bool server_streaming = 6 [default = false]; +inline bool MethodDescriptorProto::_internal_has_server_streaming() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool MethodDescriptorProto::has_server_streaming() const { + return _internal_has_server_streaming(); +} +inline void MethodDescriptorProto::clear_server_streaming() { + server_streaming_ = false; + _has_bits_[0] &= ~0x00000020u; +} +inline bool MethodDescriptorProto::_internal_server_streaming() const { + return server_streaming_; +} +inline bool MethodDescriptorProto::server_streaming() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.server_streaming) + return _internal_server_streaming(); +} +inline void MethodDescriptorProto::_internal_set_server_streaming(bool value) { + _has_bits_[0] |= 0x00000020u; + server_streaming_ = value; +} +inline void MethodDescriptorProto::set_server_streaming(bool value) { + _internal_set_server_streaming(value); + // @@protoc_insertion_point(field_set:google.protobuf.MethodDescriptorProto.server_streaming) +} + +// ------------------------------------------------------------------- + +// FileOptions + +// optional string java_package = 1; +inline bool FileOptions::_internal_has_java_package() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool FileOptions::has_java_package() const { + return _internal_has_java_package(); +} +inline void FileOptions::clear_java_package() { + java_package_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& FileOptions::java_package() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.java_package) + return _internal_java_package(); +} +inline void FileOptions::set_java_package(const std::string& value) { + _internal_set_java_package(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.java_package) +} +inline std::string* FileOptions::mutable_java_package() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.java_package) + return _internal_mutable_java_package(); +} +inline const std::string& FileOptions::_internal_java_package() const { + return java_package_.Get(); +} +inline void FileOptions::_internal_set_java_package(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + java_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_java_package(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + java_package_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.java_package) +} +inline void FileOptions::set_java_package(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + java_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.java_package) +} +inline void FileOptions::set_java_package(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + java_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.java_package) +} +inline std::string* FileOptions::_internal_mutable_java_package() { + _has_bits_[0] |= 0x00000001u; + return java_package_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_java_package() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.java_package) + if (!_internal_has_java_package()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return java_package_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_java_package(std::string* java_package) { + if (java_package != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + java_package_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), java_package, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.java_package) +} + +// optional string java_outer_classname = 8; +inline bool FileOptions::_internal_has_java_outer_classname() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool FileOptions::has_java_outer_classname() const { + return _internal_has_java_outer_classname(); +} +inline void FileOptions::clear_java_outer_classname() { + java_outer_classname_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& FileOptions::java_outer_classname() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.java_outer_classname) + return _internal_java_outer_classname(); +} +inline void FileOptions::set_java_outer_classname(const std::string& value) { + _internal_set_java_outer_classname(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.java_outer_classname) +} +inline std::string* FileOptions::mutable_java_outer_classname() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.java_outer_classname) + return _internal_mutable_java_outer_classname(); +} +inline const std::string& FileOptions::_internal_java_outer_classname() const { + return java_outer_classname_.Get(); +} +inline void FileOptions::_internal_set_java_outer_classname(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + java_outer_classname_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_java_outer_classname(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + java_outer_classname_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.java_outer_classname) +} +inline void FileOptions::set_java_outer_classname(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + java_outer_classname_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.java_outer_classname) +} +inline void FileOptions::set_java_outer_classname(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + java_outer_classname_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.java_outer_classname) +} +inline std::string* FileOptions::_internal_mutable_java_outer_classname() { + _has_bits_[0] |= 0x00000002u; + return java_outer_classname_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_java_outer_classname() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.java_outer_classname) + if (!_internal_has_java_outer_classname()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return java_outer_classname_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_java_outer_classname(std::string* java_outer_classname) { + if (java_outer_classname != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + java_outer_classname_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), java_outer_classname, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.java_outer_classname) +} + +// optional bool java_multiple_files = 10 [default = false]; +inline bool FileOptions::_internal_has_java_multiple_files() const { + bool value = (_has_bits_[0] & 0x00000400u) != 0; + return value; +} +inline bool FileOptions::has_java_multiple_files() const { + return _internal_has_java_multiple_files(); +} +inline void FileOptions::clear_java_multiple_files() { + java_multiple_files_ = false; + _has_bits_[0] &= ~0x00000400u; +} +inline bool FileOptions::_internal_java_multiple_files() const { + return java_multiple_files_; +} +inline bool FileOptions::java_multiple_files() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.java_multiple_files) + return _internal_java_multiple_files(); +} +inline void FileOptions::_internal_set_java_multiple_files(bool value) { + _has_bits_[0] |= 0x00000400u; + java_multiple_files_ = value; +} +inline void FileOptions::set_java_multiple_files(bool value) { + _internal_set_java_multiple_files(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.java_multiple_files) +} + +// optional bool java_generate_equals_and_hash = 20 [deprecated = true]; +inline bool FileOptions::_internal_has_java_generate_equals_and_hash() const { + bool value = (_has_bits_[0] & 0x00000800u) != 0; + return value; +} +inline bool FileOptions::has_java_generate_equals_and_hash() const { + return _internal_has_java_generate_equals_and_hash(); +} +inline void FileOptions::clear_java_generate_equals_and_hash() { + java_generate_equals_and_hash_ = false; + _has_bits_[0] &= ~0x00000800u; +} +inline bool FileOptions::_internal_java_generate_equals_and_hash() const { + return java_generate_equals_and_hash_; +} +inline bool FileOptions::java_generate_equals_and_hash() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.java_generate_equals_and_hash) + return _internal_java_generate_equals_and_hash(); +} +inline void FileOptions::_internal_set_java_generate_equals_and_hash(bool value) { + _has_bits_[0] |= 0x00000800u; + java_generate_equals_and_hash_ = value; +} +inline void FileOptions::set_java_generate_equals_and_hash(bool value) { + _internal_set_java_generate_equals_and_hash(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.java_generate_equals_and_hash) +} + +// optional bool java_string_check_utf8 = 27 [default = false]; +inline bool FileOptions::_internal_has_java_string_check_utf8() const { + bool value = (_has_bits_[0] & 0x00001000u) != 0; + return value; +} +inline bool FileOptions::has_java_string_check_utf8() const { + return _internal_has_java_string_check_utf8(); +} +inline void FileOptions::clear_java_string_check_utf8() { + java_string_check_utf8_ = false; + _has_bits_[0] &= ~0x00001000u; +} +inline bool FileOptions::_internal_java_string_check_utf8() const { + return java_string_check_utf8_; +} +inline bool FileOptions::java_string_check_utf8() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.java_string_check_utf8) + return _internal_java_string_check_utf8(); +} +inline void FileOptions::_internal_set_java_string_check_utf8(bool value) { + _has_bits_[0] |= 0x00001000u; + java_string_check_utf8_ = value; +} +inline void FileOptions::set_java_string_check_utf8(bool value) { + _internal_set_java_string_check_utf8(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.java_string_check_utf8) +} + +// optional .google.protobuf.FileOptions.OptimizeMode optimize_for = 9 [default = SPEED]; +inline bool FileOptions::_internal_has_optimize_for() const { + bool value = (_has_bits_[0] & 0x00040000u) != 0; + return value; +} +inline bool FileOptions::has_optimize_for() const { + return _internal_has_optimize_for(); +} +inline void FileOptions::clear_optimize_for() { + optimize_for_ = 1; + _has_bits_[0] &= ~0x00040000u; +} +inline PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode FileOptions::_internal_optimize_for() const { + return static_cast< PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode >(optimize_for_); +} +inline PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode FileOptions::optimize_for() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.optimize_for) + return _internal_optimize_for(); +} +inline void FileOptions::_internal_set_optimize_for(PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode value) { + assert(PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode_IsValid(value)); + _has_bits_[0] |= 0x00040000u; + optimize_for_ = value; +} +inline void FileOptions::set_optimize_for(PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode value) { + _internal_set_optimize_for(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.optimize_for) +} + +// optional string go_package = 11; +inline bool FileOptions::_internal_has_go_package() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool FileOptions::has_go_package() const { + return _internal_has_go_package(); +} +inline void FileOptions::clear_go_package() { + go_package_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& FileOptions::go_package() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.go_package) + return _internal_go_package(); +} +inline void FileOptions::set_go_package(const std::string& value) { + _internal_set_go_package(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.go_package) +} +inline std::string* FileOptions::mutable_go_package() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.go_package) + return _internal_mutable_go_package(); +} +inline const std::string& FileOptions::_internal_go_package() const { + return go_package_.Get(); +} +inline void FileOptions::_internal_set_go_package(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + go_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_go_package(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + go_package_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.go_package) +} +inline void FileOptions::set_go_package(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + go_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.go_package) +} +inline void FileOptions::set_go_package(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + go_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.go_package) +} +inline std::string* FileOptions::_internal_mutable_go_package() { + _has_bits_[0] |= 0x00000004u; + return go_package_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_go_package() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.go_package) + if (!_internal_has_go_package()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return go_package_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_go_package(std::string* go_package) { + if (go_package != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + go_package_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), go_package, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.go_package) +} + +// optional bool cc_generic_services = 16 [default = false]; +inline bool FileOptions::_internal_has_cc_generic_services() const { + bool value = (_has_bits_[0] & 0x00002000u) != 0; + return value; +} +inline bool FileOptions::has_cc_generic_services() const { + return _internal_has_cc_generic_services(); +} +inline void FileOptions::clear_cc_generic_services() { + cc_generic_services_ = false; + _has_bits_[0] &= ~0x00002000u; +} +inline bool FileOptions::_internal_cc_generic_services() const { + return cc_generic_services_; +} +inline bool FileOptions::cc_generic_services() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.cc_generic_services) + return _internal_cc_generic_services(); +} +inline void FileOptions::_internal_set_cc_generic_services(bool value) { + _has_bits_[0] |= 0x00002000u; + cc_generic_services_ = value; +} +inline void FileOptions::set_cc_generic_services(bool value) { + _internal_set_cc_generic_services(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.cc_generic_services) +} + +// optional bool java_generic_services = 17 [default = false]; +inline bool FileOptions::_internal_has_java_generic_services() const { + bool value = (_has_bits_[0] & 0x00004000u) != 0; + return value; +} +inline bool FileOptions::has_java_generic_services() const { + return _internal_has_java_generic_services(); +} +inline void FileOptions::clear_java_generic_services() { + java_generic_services_ = false; + _has_bits_[0] &= ~0x00004000u; +} +inline bool FileOptions::_internal_java_generic_services() const { + return java_generic_services_; +} +inline bool FileOptions::java_generic_services() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.java_generic_services) + return _internal_java_generic_services(); +} +inline void FileOptions::_internal_set_java_generic_services(bool value) { + _has_bits_[0] |= 0x00004000u; + java_generic_services_ = value; +} +inline void FileOptions::set_java_generic_services(bool value) { + _internal_set_java_generic_services(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.java_generic_services) +} + +// optional bool py_generic_services = 18 [default = false]; +inline bool FileOptions::_internal_has_py_generic_services() const { + bool value = (_has_bits_[0] & 0x00008000u) != 0; + return value; +} +inline bool FileOptions::has_py_generic_services() const { + return _internal_has_py_generic_services(); +} +inline void FileOptions::clear_py_generic_services() { + py_generic_services_ = false; + _has_bits_[0] &= ~0x00008000u; +} +inline bool FileOptions::_internal_py_generic_services() const { + return py_generic_services_; +} +inline bool FileOptions::py_generic_services() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.py_generic_services) + return _internal_py_generic_services(); +} +inline void FileOptions::_internal_set_py_generic_services(bool value) { + _has_bits_[0] |= 0x00008000u; + py_generic_services_ = value; +} +inline void FileOptions::set_py_generic_services(bool value) { + _internal_set_py_generic_services(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.py_generic_services) +} + +// optional bool php_generic_services = 42 [default = false]; +inline bool FileOptions::_internal_has_php_generic_services() const { + bool value = (_has_bits_[0] & 0x00010000u) != 0; + return value; +} +inline bool FileOptions::has_php_generic_services() const { + return _internal_has_php_generic_services(); +} +inline void FileOptions::clear_php_generic_services() { + php_generic_services_ = false; + _has_bits_[0] &= ~0x00010000u; +} +inline bool FileOptions::_internal_php_generic_services() const { + return php_generic_services_; +} +inline bool FileOptions::php_generic_services() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.php_generic_services) + return _internal_php_generic_services(); +} +inline void FileOptions::_internal_set_php_generic_services(bool value) { + _has_bits_[0] |= 0x00010000u; + php_generic_services_ = value; +} +inline void FileOptions::set_php_generic_services(bool value) { + _internal_set_php_generic_services(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.php_generic_services) +} + +// optional bool deprecated = 23 [default = false]; +inline bool FileOptions::_internal_has_deprecated() const { + bool value = (_has_bits_[0] & 0x00020000u) != 0; + return value; +} +inline bool FileOptions::has_deprecated() const { + return _internal_has_deprecated(); +} +inline void FileOptions::clear_deprecated() { + deprecated_ = false; + _has_bits_[0] &= ~0x00020000u; +} +inline bool FileOptions::_internal_deprecated() const { + return deprecated_; +} +inline bool FileOptions::deprecated() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.deprecated) + return _internal_deprecated(); +} +inline void FileOptions::_internal_set_deprecated(bool value) { + _has_bits_[0] |= 0x00020000u; + deprecated_ = value; +} +inline void FileOptions::set_deprecated(bool value) { + _internal_set_deprecated(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.deprecated) +} + +// optional bool cc_enable_arenas = 31 [default = true]; +inline bool FileOptions::_internal_has_cc_enable_arenas() const { + bool value = (_has_bits_[0] & 0x00080000u) != 0; + return value; +} +inline bool FileOptions::has_cc_enable_arenas() const { + return _internal_has_cc_enable_arenas(); +} +inline void FileOptions::clear_cc_enable_arenas() { + cc_enable_arenas_ = true; + _has_bits_[0] &= ~0x00080000u; +} +inline bool FileOptions::_internal_cc_enable_arenas() const { + return cc_enable_arenas_; +} +inline bool FileOptions::cc_enable_arenas() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.cc_enable_arenas) + return _internal_cc_enable_arenas(); +} +inline void FileOptions::_internal_set_cc_enable_arenas(bool value) { + _has_bits_[0] |= 0x00080000u; + cc_enable_arenas_ = value; +} +inline void FileOptions::set_cc_enable_arenas(bool value) { + _internal_set_cc_enable_arenas(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.cc_enable_arenas) +} + +// optional string objc_class_prefix = 36; +inline bool FileOptions::_internal_has_objc_class_prefix() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool FileOptions::has_objc_class_prefix() const { + return _internal_has_objc_class_prefix(); +} +inline void FileOptions::clear_objc_class_prefix() { + objc_class_prefix_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& FileOptions::objc_class_prefix() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.objc_class_prefix) + return _internal_objc_class_prefix(); +} +inline void FileOptions::set_objc_class_prefix(const std::string& value) { + _internal_set_objc_class_prefix(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.objc_class_prefix) +} +inline std::string* FileOptions::mutable_objc_class_prefix() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.objc_class_prefix) + return _internal_mutable_objc_class_prefix(); +} +inline const std::string& FileOptions::_internal_objc_class_prefix() const { + return objc_class_prefix_.Get(); +} +inline void FileOptions::_internal_set_objc_class_prefix(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + objc_class_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_objc_class_prefix(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + objc_class_prefix_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.objc_class_prefix) +} +inline void FileOptions::set_objc_class_prefix(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + objc_class_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.objc_class_prefix) +} +inline void FileOptions::set_objc_class_prefix(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000008u; + objc_class_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.objc_class_prefix) +} +inline std::string* FileOptions::_internal_mutable_objc_class_prefix() { + _has_bits_[0] |= 0x00000008u; + return objc_class_prefix_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_objc_class_prefix() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.objc_class_prefix) + if (!_internal_has_objc_class_prefix()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return objc_class_prefix_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_objc_class_prefix(std::string* objc_class_prefix) { + if (objc_class_prefix != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + objc_class_prefix_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), objc_class_prefix, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.objc_class_prefix) +} + +// optional string csharp_namespace = 37; +inline bool FileOptions::_internal_has_csharp_namespace() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool FileOptions::has_csharp_namespace() const { + return _internal_has_csharp_namespace(); +} +inline void FileOptions::clear_csharp_namespace() { + csharp_namespace_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000010u; +} +inline const std::string& FileOptions::csharp_namespace() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.csharp_namespace) + return _internal_csharp_namespace(); +} +inline void FileOptions::set_csharp_namespace(const std::string& value) { + _internal_set_csharp_namespace(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.csharp_namespace) +} +inline std::string* FileOptions::mutable_csharp_namespace() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.csharp_namespace) + return _internal_mutable_csharp_namespace(); +} +inline const std::string& FileOptions::_internal_csharp_namespace() const { + return csharp_namespace_.Get(); +} +inline void FileOptions::_internal_set_csharp_namespace(const std::string& value) { + _has_bits_[0] |= 0x00000010u; + csharp_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_csharp_namespace(std::string&& value) { + _has_bits_[0] |= 0x00000010u; + csharp_namespace_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.csharp_namespace) +} +inline void FileOptions::set_csharp_namespace(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000010u; + csharp_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.csharp_namespace) +} +inline void FileOptions::set_csharp_namespace(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000010u; + csharp_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.csharp_namespace) +} +inline std::string* FileOptions::_internal_mutable_csharp_namespace() { + _has_bits_[0] |= 0x00000010u; + return csharp_namespace_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_csharp_namespace() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.csharp_namespace) + if (!_internal_has_csharp_namespace()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000010u; + return csharp_namespace_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_csharp_namespace(std::string* csharp_namespace) { + if (csharp_namespace != nullptr) { + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + csharp_namespace_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), csharp_namespace, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.csharp_namespace) +} + +// optional string swift_prefix = 39; +inline bool FileOptions::_internal_has_swift_prefix() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool FileOptions::has_swift_prefix() const { + return _internal_has_swift_prefix(); +} +inline void FileOptions::clear_swift_prefix() { + swift_prefix_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000020u; +} +inline const std::string& FileOptions::swift_prefix() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.swift_prefix) + return _internal_swift_prefix(); +} +inline void FileOptions::set_swift_prefix(const std::string& value) { + _internal_set_swift_prefix(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.swift_prefix) +} +inline std::string* FileOptions::mutable_swift_prefix() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.swift_prefix) + return _internal_mutable_swift_prefix(); +} +inline const std::string& FileOptions::_internal_swift_prefix() const { + return swift_prefix_.Get(); +} +inline void FileOptions::_internal_set_swift_prefix(const std::string& value) { + _has_bits_[0] |= 0x00000020u; + swift_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_swift_prefix(std::string&& value) { + _has_bits_[0] |= 0x00000020u; + swift_prefix_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.swift_prefix) +} +inline void FileOptions::set_swift_prefix(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000020u; + swift_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.swift_prefix) +} +inline void FileOptions::set_swift_prefix(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000020u; + swift_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.swift_prefix) +} +inline std::string* FileOptions::_internal_mutable_swift_prefix() { + _has_bits_[0] |= 0x00000020u; + return swift_prefix_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_swift_prefix() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.swift_prefix) + if (!_internal_has_swift_prefix()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000020u; + return swift_prefix_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_swift_prefix(std::string* swift_prefix) { + if (swift_prefix != nullptr) { + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + swift_prefix_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), swift_prefix, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.swift_prefix) +} + +// optional string php_class_prefix = 40; +inline bool FileOptions::_internal_has_php_class_prefix() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool FileOptions::has_php_class_prefix() const { + return _internal_has_php_class_prefix(); +} +inline void FileOptions::clear_php_class_prefix() { + php_class_prefix_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000040u; +} +inline const std::string& FileOptions::php_class_prefix() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.php_class_prefix) + return _internal_php_class_prefix(); +} +inline void FileOptions::set_php_class_prefix(const std::string& value) { + _internal_set_php_class_prefix(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.php_class_prefix) +} +inline std::string* FileOptions::mutable_php_class_prefix() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.php_class_prefix) + return _internal_mutable_php_class_prefix(); +} +inline const std::string& FileOptions::_internal_php_class_prefix() const { + return php_class_prefix_.Get(); +} +inline void FileOptions::_internal_set_php_class_prefix(const std::string& value) { + _has_bits_[0] |= 0x00000040u; + php_class_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_php_class_prefix(std::string&& value) { + _has_bits_[0] |= 0x00000040u; + php_class_prefix_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.php_class_prefix) +} +inline void FileOptions::set_php_class_prefix(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000040u; + php_class_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.php_class_prefix) +} +inline void FileOptions::set_php_class_prefix(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000040u; + php_class_prefix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.php_class_prefix) +} +inline std::string* FileOptions::_internal_mutable_php_class_prefix() { + _has_bits_[0] |= 0x00000040u; + return php_class_prefix_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_php_class_prefix() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.php_class_prefix) + if (!_internal_has_php_class_prefix()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000040u; + return php_class_prefix_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_php_class_prefix(std::string* php_class_prefix) { + if (php_class_prefix != nullptr) { + _has_bits_[0] |= 0x00000040u; + } else { + _has_bits_[0] &= ~0x00000040u; + } + php_class_prefix_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), php_class_prefix, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.php_class_prefix) +} + +// optional string php_namespace = 41; +inline bool FileOptions::_internal_has_php_namespace() const { + bool value = (_has_bits_[0] & 0x00000080u) != 0; + return value; +} +inline bool FileOptions::has_php_namespace() const { + return _internal_has_php_namespace(); +} +inline void FileOptions::clear_php_namespace() { + php_namespace_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000080u; +} +inline const std::string& FileOptions::php_namespace() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.php_namespace) + return _internal_php_namespace(); +} +inline void FileOptions::set_php_namespace(const std::string& value) { + _internal_set_php_namespace(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.php_namespace) +} +inline std::string* FileOptions::mutable_php_namespace() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.php_namespace) + return _internal_mutable_php_namespace(); +} +inline const std::string& FileOptions::_internal_php_namespace() const { + return php_namespace_.Get(); +} +inline void FileOptions::_internal_set_php_namespace(const std::string& value) { + _has_bits_[0] |= 0x00000080u; + php_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_php_namespace(std::string&& value) { + _has_bits_[0] |= 0x00000080u; + php_namespace_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.php_namespace) +} +inline void FileOptions::set_php_namespace(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000080u; + php_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.php_namespace) +} +inline void FileOptions::set_php_namespace(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000080u; + php_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.php_namespace) +} +inline std::string* FileOptions::_internal_mutable_php_namespace() { + _has_bits_[0] |= 0x00000080u; + return php_namespace_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_php_namespace() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.php_namespace) + if (!_internal_has_php_namespace()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000080u; + return php_namespace_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_php_namespace(std::string* php_namespace) { + if (php_namespace != nullptr) { + _has_bits_[0] |= 0x00000080u; + } else { + _has_bits_[0] &= ~0x00000080u; + } + php_namespace_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), php_namespace, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.php_namespace) +} + +// optional string php_metadata_namespace = 44; +inline bool FileOptions::_internal_has_php_metadata_namespace() const { + bool value = (_has_bits_[0] & 0x00000100u) != 0; + return value; +} +inline bool FileOptions::has_php_metadata_namespace() const { + return _internal_has_php_metadata_namespace(); +} +inline void FileOptions::clear_php_metadata_namespace() { + php_metadata_namespace_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000100u; +} +inline const std::string& FileOptions::php_metadata_namespace() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.php_metadata_namespace) + return _internal_php_metadata_namespace(); +} +inline void FileOptions::set_php_metadata_namespace(const std::string& value) { + _internal_set_php_metadata_namespace(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.php_metadata_namespace) +} +inline std::string* FileOptions::mutable_php_metadata_namespace() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.php_metadata_namespace) + return _internal_mutable_php_metadata_namespace(); +} +inline const std::string& FileOptions::_internal_php_metadata_namespace() const { + return php_metadata_namespace_.Get(); +} +inline void FileOptions::_internal_set_php_metadata_namespace(const std::string& value) { + _has_bits_[0] |= 0x00000100u; + php_metadata_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_php_metadata_namespace(std::string&& value) { + _has_bits_[0] |= 0x00000100u; + php_metadata_namespace_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.php_metadata_namespace) +} +inline void FileOptions::set_php_metadata_namespace(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000100u; + php_metadata_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.php_metadata_namespace) +} +inline void FileOptions::set_php_metadata_namespace(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000100u; + php_metadata_namespace_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.php_metadata_namespace) +} +inline std::string* FileOptions::_internal_mutable_php_metadata_namespace() { + _has_bits_[0] |= 0x00000100u; + return php_metadata_namespace_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_php_metadata_namespace() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.php_metadata_namespace) + if (!_internal_has_php_metadata_namespace()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000100u; + return php_metadata_namespace_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_php_metadata_namespace(std::string* php_metadata_namespace) { + if (php_metadata_namespace != nullptr) { + _has_bits_[0] |= 0x00000100u; + } else { + _has_bits_[0] &= ~0x00000100u; + } + php_metadata_namespace_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), php_metadata_namespace, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.php_metadata_namespace) +} + +// optional string ruby_package = 45; +inline bool FileOptions::_internal_has_ruby_package() const { + bool value = (_has_bits_[0] & 0x00000200u) != 0; + return value; +} +inline bool FileOptions::has_ruby_package() const { + return _internal_has_ruby_package(); +} +inline void FileOptions::clear_ruby_package() { + ruby_package_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000200u; +} +inline const std::string& FileOptions::ruby_package() const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.ruby_package) + return _internal_ruby_package(); +} +inline void FileOptions::set_ruby_package(const std::string& value) { + _internal_set_ruby_package(value); + // @@protoc_insertion_point(field_set:google.protobuf.FileOptions.ruby_package) +} +inline std::string* FileOptions::mutable_ruby_package() { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.ruby_package) + return _internal_mutable_ruby_package(); +} +inline const std::string& FileOptions::_internal_ruby_package() const { + return ruby_package_.Get(); +} +inline void FileOptions::_internal_set_ruby_package(const std::string& value) { + _has_bits_[0] |= 0x00000200u; + ruby_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void FileOptions::set_ruby_package(std::string&& value) { + _has_bits_[0] |= 0x00000200u; + ruby_package_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.FileOptions.ruby_package) +} +inline void FileOptions::set_ruby_package(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000200u; + ruby_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.FileOptions.ruby_package) +} +inline void FileOptions::set_ruby_package(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000200u; + ruby_package_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FileOptions.ruby_package) +} +inline std::string* FileOptions::_internal_mutable_ruby_package() { + _has_bits_[0] |= 0x00000200u; + return ruby_package_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* FileOptions::release_ruby_package() { + // @@protoc_insertion_point(field_release:google.protobuf.FileOptions.ruby_package) + if (!_internal_has_ruby_package()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000200u; + return ruby_package_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void FileOptions::set_allocated_ruby_package(std::string* ruby_package) { + if (ruby_package != nullptr) { + _has_bits_[0] |= 0x00000200u; + } else { + _has_bits_[0] &= ~0x00000200u; + } + ruby_package_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ruby_package, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.FileOptions.ruby_package) +} + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int FileOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int FileOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void FileOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* FileOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FileOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +FileOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FileOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& FileOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& FileOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FileOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* FileOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* FileOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.FileOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +FileOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.FileOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// MessageOptions + +// optional bool message_set_wire_format = 1 [default = false]; +inline bool MessageOptions::_internal_has_message_set_wire_format() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool MessageOptions::has_message_set_wire_format() const { + return _internal_has_message_set_wire_format(); +} +inline void MessageOptions::clear_message_set_wire_format() { + message_set_wire_format_ = false; + _has_bits_[0] &= ~0x00000001u; +} +inline bool MessageOptions::_internal_message_set_wire_format() const { + return message_set_wire_format_; +} +inline bool MessageOptions::message_set_wire_format() const { + // @@protoc_insertion_point(field_get:google.protobuf.MessageOptions.message_set_wire_format) + return _internal_message_set_wire_format(); +} +inline void MessageOptions::_internal_set_message_set_wire_format(bool value) { + _has_bits_[0] |= 0x00000001u; + message_set_wire_format_ = value; +} +inline void MessageOptions::set_message_set_wire_format(bool value) { + _internal_set_message_set_wire_format(value); + // @@protoc_insertion_point(field_set:google.protobuf.MessageOptions.message_set_wire_format) +} + +// optional bool no_standard_descriptor_accessor = 2 [default = false]; +inline bool MessageOptions::_internal_has_no_standard_descriptor_accessor() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool MessageOptions::has_no_standard_descriptor_accessor() const { + return _internal_has_no_standard_descriptor_accessor(); +} +inline void MessageOptions::clear_no_standard_descriptor_accessor() { + no_standard_descriptor_accessor_ = false; + _has_bits_[0] &= ~0x00000002u; +} +inline bool MessageOptions::_internal_no_standard_descriptor_accessor() const { + return no_standard_descriptor_accessor_; +} +inline bool MessageOptions::no_standard_descriptor_accessor() const { + // @@protoc_insertion_point(field_get:google.protobuf.MessageOptions.no_standard_descriptor_accessor) + return _internal_no_standard_descriptor_accessor(); +} +inline void MessageOptions::_internal_set_no_standard_descriptor_accessor(bool value) { + _has_bits_[0] |= 0x00000002u; + no_standard_descriptor_accessor_ = value; +} +inline void MessageOptions::set_no_standard_descriptor_accessor(bool value) { + _internal_set_no_standard_descriptor_accessor(value); + // @@protoc_insertion_point(field_set:google.protobuf.MessageOptions.no_standard_descriptor_accessor) +} + +// optional bool deprecated = 3 [default = false]; +inline bool MessageOptions::_internal_has_deprecated() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool MessageOptions::has_deprecated() const { + return _internal_has_deprecated(); +} +inline void MessageOptions::clear_deprecated() { + deprecated_ = false; + _has_bits_[0] &= ~0x00000004u; +} +inline bool MessageOptions::_internal_deprecated() const { + return deprecated_; +} +inline bool MessageOptions::deprecated() const { + // @@protoc_insertion_point(field_get:google.protobuf.MessageOptions.deprecated) + return _internal_deprecated(); +} +inline void MessageOptions::_internal_set_deprecated(bool value) { + _has_bits_[0] |= 0x00000004u; + deprecated_ = value; +} +inline void MessageOptions::set_deprecated(bool value) { + _internal_set_deprecated(value); + // @@protoc_insertion_point(field_set:google.protobuf.MessageOptions.deprecated) +} + +// optional bool map_entry = 7; +inline bool MessageOptions::_internal_has_map_entry() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool MessageOptions::has_map_entry() const { + return _internal_has_map_entry(); +} +inline void MessageOptions::clear_map_entry() { + map_entry_ = false; + _has_bits_[0] &= ~0x00000008u; +} +inline bool MessageOptions::_internal_map_entry() const { + return map_entry_; +} +inline bool MessageOptions::map_entry() const { + // @@protoc_insertion_point(field_get:google.protobuf.MessageOptions.map_entry) + return _internal_map_entry(); +} +inline void MessageOptions::_internal_set_map_entry(bool value) { + _has_bits_[0] |= 0x00000008u; + map_entry_ = value; +} +inline void MessageOptions::set_map_entry(bool value) { + _internal_set_map_entry(value); + // @@protoc_insertion_point(field_set:google.protobuf.MessageOptions.map_entry) +} + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int MessageOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int MessageOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void MessageOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* MessageOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.MessageOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +MessageOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.MessageOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& MessageOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& MessageOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.MessageOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* MessageOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* MessageOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.MessageOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +MessageOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.MessageOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// FieldOptions + +// optional .google.protobuf.FieldOptions.CType ctype = 1 [default = STRING]; +inline bool FieldOptions::_internal_has_ctype() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool FieldOptions::has_ctype() const { + return _internal_has_ctype(); +} +inline void FieldOptions::clear_ctype() { + ctype_ = 0; + _has_bits_[0] &= ~0x00000001u; +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions_CType FieldOptions::_internal_ctype() const { + return static_cast< PROTOBUF_NAMESPACE_ID::FieldOptions_CType >(ctype_); +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions_CType FieldOptions::ctype() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldOptions.ctype) + return _internal_ctype(); +} +inline void FieldOptions::_internal_set_ctype(PROTOBUF_NAMESPACE_ID::FieldOptions_CType value) { + assert(PROTOBUF_NAMESPACE_ID::FieldOptions_CType_IsValid(value)); + _has_bits_[0] |= 0x00000001u; + ctype_ = value; +} +inline void FieldOptions::set_ctype(PROTOBUF_NAMESPACE_ID::FieldOptions_CType value) { + _internal_set_ctype(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldOptions.ctype) +} + +// optional bool packed = 2; +inline bool FieldOptions::_internal_has_packed() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool FieldOptions::has_packed() const { + return _internal_has_packed(); +} +inline void FieldOptions::clear_packed() { + packed_ = false; + _has_bits_[0] &= ~0x00000002u; +} +inline bool FieldOptions::_internal_packed() const { + return packed_; +} +inline bool FieldOptions::packed() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldOptions.packed) + return _internal_packed(); +} +inline void FieldOptions::_internal_set_packed(bool value) { + _has_bits_[0] |= 0x00000002u; + packed_ = value; +} +inline void FieldOptions::set_packed(bool value) { + _internal_set_packed(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldOptions.packed) +} + +// optional .google.protobuf.FieldOptions.JSType jstype = 6 [default = JS_NORMAL]; +inline bool FieldOptions::_internal_has_jstype() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool FieldOptions::has_jstype() const { + return _internal_has_jstype(); +} +inline void FieldOptions::clear_jstype() { + jstype_ = 0; + _has_bits_[0] &= ~0x00000020u; +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions_JSType FieldOptions::_internal_jstype() const { + return static_cast< PROTOBUF_NAMESPACE_ID::FieldOptions_JSType >(jstype_); +} +inline PROTOBUF_NAMESPACE_ID::FieldOptions_JSType FieldOptions::jstype() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldOptions.jstype) + return _internal_jstype(); +} +inline void FieldOptions::_internal_set_jstype(PROTOBUF_NAMESPACE_ID::FieldOptions_JSType value) { + assert(PROTOBUF_NAMESPACE_ID::FieldOptions_JSType_IsValid(value)); + _has_bits_[0] |= 0x00000020u; + jstype_ = value; +} +inline void FieldOptions::set_jstype(PROTOBUF_NAMESPACE_ID::FieldOptions_JSType value) { + _internal_set_jstype(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldOptions.jstype) +} + +// optional bool lazy = 5 [default = false]; +inline bool FieldOptions::_internal_has_lazy() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool FieldOptions::has_lazy() const { + return _internal_has_lazy(); +} +inline void FieldOptions::clear_lazy() { + lazy_ = false; + _has_bits_[0] &= ~0x00000004u; +} +inline bool FieldOptions::_internal_lazy() const { + return lazy_; +} +inline bool FieldOptions::lazy() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldOptions.lazy) + return _internal_lazy(); +} +inline void FieldOptions::_internal_set_lazy(bool value) { + _has_bits_[0] |= 0x00000004u; + lazy_ = value; +} +inline void FieldOptions::set_lazy(bool value) { + _internal_set_lazy(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldOptions.lazy) +} + +// optional bool deprecated = 3 [default = false]; +inline bool FieldOptions::_internal_has_deprecated() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool FieldOptions::has_deprecated() const { + return _internal_has_deprecated(); +} +inline void FieldOptions::clear_deprecated() { + deprecated_ = false; + _has_bits_[0] &= ~0x00000008u; +} +inline bool FieldOptions::_internal_deprecated() const { + return deprecated_; +} +inline bool FieldOptions::deprecated() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldOptions.deprecated) + return _internal_deprecated(); +} +inline void FieldOptions::_internal_set_deprecated(bool value) { + _has_bits_[0] |= 0x00000008u; + deprecated_ = value; +} +inline void FieldOptions::set_deprecated(bool value) { + _internal_set_deprecated(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldOptions.deprecated) +} + +// optional bool weak = 10 [default = false]; +inline bool FieldOptions::_internal_has_weak() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool FieldOptions::has_weak() const { + return _internal_has_weak(); +} +inline void FieldOptions::clear_weak() { + weak_ = false; + _has_bits_[0] &= ~0x00000010u; +} +inline bool FieldOptions::_internal_weak() const { + return weak_; +} +inline bool FieldOptions::weak() const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldOptions.weak) + return _internal_weak(); +} +inline void FieldOptions::_internal_set_weak(bool value) { + _has_bits_[0] |= 0x00000010u; + weak_ = value; +} +inline void FieldOptions::set_weak(bool value) { + _internal_set_weak(value); + // @@protoc_insertion_point(field_set:google.protobuf.FieldOptions.weak) +} + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int FieldOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int FieldOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void FieldOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* FieldOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +FieldOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FieldOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& FieldOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& FieldOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* FieldOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* FieldOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.FieldOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +FieldOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.FieldOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// OneofOptions + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int OneofOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int OneofOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void OneofOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* OneofOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.OneofOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +OneofOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.OneofOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& OneofOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& OneofOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.OneofOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* OneofOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* OneofOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.OneofOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +OneofOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.OneofOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// EnumOptions + +// optional bool allow_alias = 2; +inline bool EnumOptions::_internal_has_allow_alias() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool EnumOptions::has_allow_alias() const { + return _internal_has_allow_alias(); +} +inline void EnumOptions::clear_allow_alias() { + allow_alias_ = false; + _has_bits_[0] &= ~0x00000001u; +} +inline bool EnumOptions::_internal_allow_alias() const { + return allow_alias_; +} +inline bool EnumOptions::allow_alias() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumOptions.allow_alias) + return _internal_allow_alias(); +} +inline void EnumOptions::_internal_set_allow_alias(bool value) { + _has_bits_[0] |= 0x00000001u; + allow_alias_ = value; +} +inline void EnumOptions::set_allow_alias(bool value) { + _internal_set_allow_alias(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumOptions.allow_alias) +} + +// optional bool deprecated = 3 [default = false]; +inline bool EnumOptions::_internal_has_deprecated() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool EnumOptions::has_deprecated() const { + return _internal_has_deprecated(); +} +inline void EnumOptions::clear_deprecated() { + deprecated_ = false; + _has_bits_[0] &= ~0x00000002u; +} +inline bool EnumOptions::_internal_deprecated() const { + return deprecated_; +} +inline bool EnumOptions::deprecated() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumOptions.deprecated) + return _internal_deprecated(); +} +inline void EnumOptions::_internal_set_deprecated(bool value) { + _has_bits_[0] |= 0x00000002u; + deprecated_ = value; +} +inline void EnumOptions::set_deprecated(bool value) { + _internal_set_deprecated(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumOptions.deprecated) +} + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int EnumOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int EnumOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void EnumOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* EnumOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +EnumOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.EnumOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& EnumOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& EnumOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* EnumOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* EnumOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.EnumOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +EnumOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.EnumOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// EnumValueOptions + +// optional bool deprecated = 1 [default = false]; +inline bool EnumValueOptions::_internal_has_deprecated() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool EnumValueOptions::has_deprecated() const { + return _internal_has_deprecated(); +} +inline void EnumValueOptions::clear_deprecated() { + deprecated_ = false; + _has_bits_[0] &= ~0x00000001u; +} +inline bool EnumValueOptions::_internal_deprecated() const { + return deprecated_; +} +inline bool EnumValueOptions::deprecated() const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumValueOptions.deprecated) + return _internal_deprecated(); +} +inline void EnumValueOptions::_internal_set_deprecated(bool value) { + _has_bits_[0] |= 0x00000001u; + deprecated_ = value; +} +inline void EnumValueOptions::set_deprecated(bool value) { + _internal_set_deprecated(value); + // @@protoc_insertion_point(field_set:google.protobuf.EnumValueOptions.deprecated) +} + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int EnumValueOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int EnumValueOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void EnumValueOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* EnumValueOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.EnumValueOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +EnumValueOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.EnumValueOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& EnumValueOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& EnumValueOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.EnumValueOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* EnumValueOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* EnumValueOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.EnumValueOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +EnumValueOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.EnumValueOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// ServiceOptions + +// optional bool deprecated = 33 [default = false]; +inline bool ServiceOptions::_internal_has_deprecated() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ServiceOptions::has_deprecated() const { + return _internal_has_deprecated(); +} +inline void ServiceOptions::clear_deprecated() { + deprecated_ = false; + _has_bits_[0] &= ~0x00000001u; +} +inline bool ServiceOptions::_internal_deprecated() const { + return deprecated_; +} +inline bool ServiceOptions::deprecated() const { + // @@protoc_insertion_point(field_get:google.protobuf.ServiceOptions.deprecated) + return _internal_deprecated(); +} +inline void ServiceOptions::_internal_set_deprecated(bool value) { + _has_bits_[0] |= 0x00000001u; + deprecated_ = value; +} +inline void ServiceOptions::set_deprecated(bool value) { + _internal_set_deprecated(value); + // @@protoc_insertion_point(field_set:google.protobuf.ServiceOptions.deprecated) +} + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int ServiceOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int ServiceOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void ServiceOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* ServiceOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.ServiceOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +ServiceOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.ServiceOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& ServiceOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& ServiceOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.ServiceOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* ServiceOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* ServiceOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.ServiceOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +ServiceOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.ServiceOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// MethodOptions + +// optional bool deprecated = 33 [default = false]; +inline bool MethodOptions::_internal_has_deprecated() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool MethodOptions::has_deprecated() const { + return _internal_has_deprecated(); +} +inline void MethodOptions::clear_deprecated() { + deprecated_ = false; + _has_bits_[0] &= ~0x00000001u; +} +inline bool MethodOptions::_internal_deprecated() const { + return deprecated_; +} +inline bool MethodOptions::deprecated() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodOptions.deprecated) + return _internal_deprecated(); +} +inline void MethodOptions::_internal_set_deprecated(bool value) { + _has_bits_[0] |= 0x00000001u; + deprecated_ = value; +} +inline void MethodOptions::set_deprecated(bool value) { + _internal_set_deprecated(value); + // @@protoc_insertion_point(field_set:google.protobuf.MethodOptions.deprecated) +} + +// optional .google.protobuf.MethodOptions.IdempotencyLevel idempotency_level = 34 [default = IDEMPOTENCY_UNKNOWN]; +inline bool MethodOptions::_internal_has_idempotency_level() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool MethodOptions::has_idempotency_level() const { + return _internal_has_idempotency_level(); +} +inline void MethodOptions::clear_idempotency_level() { + idempotency_level_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel MethodOptions::_internal_idempotency_level() const { + return static_cast< PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel >(idempotency_level_); +} +inline PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel MethodOptions::idempotency_level() const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodOptions.idempotency_level) + return _internal_idempotency_level(); +} +inline void MethodOptions::_internal_set_idempotency_level(PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel value) { + assert(PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel_IsValid(value)); + _has_bits_[0] |= 0x00000002u; + idempotency_level_ = value; +} +inline void MethodOptions::set_idempotency_level(PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel value) { + _internal_set_idempotency_level(value); + // @@protoc_insertion_point(field_set:google.protobuf.MethodOptions.idempotency_level) +} + +// repeated .google.protobuf.UninterpretedOption uninterpreted_option = 999; +inline int MethodOptions::_internal_uninterpreted_option_size() const { + return uninterpreted_option_.size(); +} +inline int MethodOptions::uninterpreted_option_size() const { + return _internal_uninterpreted_option_size(); +} +inline void MethodOptions::clear_uninterpreted_option() { + uninterpreted_option_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* MethodOptions::mutable_uninterpreted_option(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.MethodOptions.uninterpreted_option) + return uninterpreted_option_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >* +MethodOptions::mutable_uninterpreted_option() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.MethodOptions.uninterpreted_option) + return &uninterpreted_option_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& MethodOptions::_internal_uninterpreted_option(int index) const { + return uninterpreted_option_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption& MethodOptions::uninterpreted_option(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.MethodOptions.uninterpreted_option) + return _internal_uninterpreted_option(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* MethodOptions::_internal_add_uninterpreted_option() { + return uninterpreted_option_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption* MethodOptions::add_uninterpreted_option() { + // @@protoc_insertion_point(field_add:google.protobuf.MethodOptions.uninterpreted_option) + return _internal_add_uninterpreted_option(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption >& +MethodOptions::uninterpreted_option() const { + // @@protoc_insertion_point(field_list:google.protobuf.MethodOptions.uninterpreted_option) + return uninterpreted_option_; +} + +// ------------------------------------------------------------------- + +// UninterpretedOption_NamePart + +// required string name_part = 1; +inline bool UninterpretedOption_NamePart::_internal_has_name_part() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool UninterpretedOption_NamePart::has_name_part() const { + return _internal_has_name_part(); +} +inline void UninterpretedOption_NamePart::clear_name_part() { + name_part_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& UninterpretedOption_NamePart::name_part() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.NamePart.name_part) + return _internal_name_part(); +} +inline void UninterpretedOption_NamePart::set_name_part(const std::string& value) { + _internal_set_name_part(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.NamePart.name_part) +} +inline std::string* UninterpretedOption_NamePart::mutable_name_part() { + // @@protoc_insertion_point(field_mutable:google.protobuf.UninterpretedOption.NamePart.name_part) + return _internal_mutable_name_part(); +} +inline const std::string& UninterpretedOption_NamePart::_internal_name_part() const { + return name_part_.Get(); +} +inline void UninterpretedOption_NamePart::_internal_set_name_part(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_part_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void UninterpretedOption_NamePart::set_name_part(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_part_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.UninterpretedOption.NamePart.name_part) +} +inline void UninterpretedOption_NamePart::set_name_part(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_part_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.UninterpretedOption.NamePart.name_part) +} +inline void UninterpretedOption_NamePart::set_name_part(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_part_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.UninterpretedOption.NamePart.name_part) +} +inline std::string* UninterpretedOption_NamePart::_internal_mutable_name_part() { + _has_bits_[0] |= 0x00000001u; + return name_part_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* UninterpretedOption_NamePart::release_name_part() { + // @@protoc_insertion_point(field_release:google.protobuf.UninterpretedOption.NamePart.name_part) + if (!_internal_has_name_part()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_part_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void UninterpretedOption_NamePart::set_allocated_name_part(std::string* name_part) { + if (name_part != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_part_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name_part, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.UninterpretedOption.NamePart.name_part) +} + +// required bool is_extension = 2; +inline bool UninterpretedOption_NamePart::_internal_has_is_extension() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool UninterpretedOption_NamePart::has_is_extension() const { + return _internal_has_is_extension(); +} +inline void UninterpretedOption_NamePart::clear_is_extension() { + is_extension_ = false; + _has_bits_[0] &= ~0x00000002u; +} +inline bool UninterpretedOption_NamePart::_internal_is_extension() const { + return is_extension_; +} +inline bool UninterpretedOption_NamePart::is_extension() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.NamePart.is_extension) + return _internal_is_extension(); +} +inline void UninterpretedOption_NamePart::_internal_set_is_extension(bool value) { + _has_bits_[0] |= 0x00000002u; + is_extension_ = value; +} +inline void UninterpretedOption_NamePart::set_is_extension(bool value) { + _internal_set_is_extension(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.NamePart.is_extension) +} + +// ------------------------------------------------------------------- + +// UninterpretedOption + +// repeated .google.protobuf.UninterpretedOption.NamePart name = 2; +inline int UninterpretedOption::_internal_name_size() const { + return name_.size(); +} +inline int UninterpretedOption::name_size() const { + return _internal_name_size(); +} +inline void UninterpretedOption::clear_name() { + name_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart* UninterpretedOption::mutable_name(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.UninterpretedOption.name) + return name_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart >* +UninterpretedOption::mutable_name() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.UninterpretedOption.name) + return &name_; +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart& UninterpretedOption::_internal_name(int index) const { + return name_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart& UninterpretedOption::name(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.name) + return _internal_name(index); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart* UninterpretedOption::_internal_add_name() { + return name_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart* UninterpretedOption::add_name() { + // @@protoc_insertion_point(field_add:google.protobuf.UninterpretedOption.name) + return _internal_add_name(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::UninterpretedOption_NamePart >& +UninterpretedOption::name() const { + // @@protoc_insertion_point(field_list:google.protobuf.UninterpretedOption.name) + return name_; +} + +// optional string identifier_value = 3; +inline bool UninterpretedOption::_internal_has_identifier_value() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool UninterpretedOption::has_identifier_value() const { + return _internal_has_identifier_value(); +} +inline void UninterpretedOption::clear_identifier_value() { + identifier_value_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& UninterpretedOption::identifier_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.identifier_value) + return _internal_identifier_value(); +} +inline void UninterpretedOption::set_identifier_value(const std::string& value) { + _internal_set_identifier_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.identifier_value) +} +inline std::string* UninterpretedOption::mutable_identifier_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.UninterpretedOption.identifier_value) + return _internal_mutable_identifier_value(); +} +inline const std::string& UninterpretedOption::_internal_identifier_value() const { + return identifier_value_.Get(); +} +inline void UninterpretedOption::_internal_set_identifier_value(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + identifier_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void UninterpretedOption::set_identifier_value(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + identifier_value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.UninterpretedOption.identifier_value) +} +inline void UninterpretedOption::set_identifier_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + identifier_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.UninterpretedOption.identifier_value) +} +inline void UninterpretedOption::set_identifier_value(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + identifier_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.UninterpretedOption.identifier_value) +} +inline std::string* UninterpretedOption::_internal_mutable_identifier_value() { + _has_bits_[0] |= 0x00000001u; + return identifier_value_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* UninterpretedOption::release_identifier_value() { + // @@protoc_insertion_point(field_release:google.protobuf.UninterpretedOption.identifier_value) + if (!_internal_has_identifier_value()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return identifier_value_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void UninterpretedOption::set_allocated_identifier_value(std::string* identifier_value) { + if (identifier_value != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + identifier_value_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), identifier_value, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.UninterpretedOption.identifier_value) +} + +// optional uint64 positive_int_value = 4; +inline bool UninterpretedOption::_internal_has_positive_int_value() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool UninterpretedOption::has_positive_int_value() const { + return _internal_has_positive_int_value(); +} +inline void UninterpretedOption::clear_positive_int_value() { + positive_int_value_ = PROTOBUF_ULONGLONG(0); + _has_bits_[0] &= ~0x00000008u; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 UninterpretedOption::_internal_positive_int_value() const { + return positive_int_value_; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 UninterpretedOption::positive_int_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.positive_int_value) + return _internal_positive_int_value(); +} +inline void UninterpretedOption::_internal_set_positive_int_value(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _has_bits_[0] |= 0x00000008u; + positive_int_value_ = value; +} +inline void UninterpretedOption::set_positive_int_value(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_set_positive_int_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.positive_int_value) +} + +// optional int64 negative_int_value = 5; +inline bool UninterpretedOption::_internal_has_negative_int_value() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool UninterpretedOption::has_negative_int_value() const { + return _internal_has_negative_int_value(); +} +inline void UninterpretedOption::clear_negative_int_value() { + negative_int_value_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000010u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 UninterpretedOption::_internal_negative_int_value() const { + return negative_int_value_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 UninterpretedOption::negative_int_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.negative_int_value) + return _internal_negative_int_value(); +} +inline void UninterpretedOption::_internal_set_negative_int_value(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000010u; + negative_int_value_ = value; +} +inline void UninterpretedOption::set_negative_int_value(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_negative_int_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.negative_int_value) +} + +// optional double double_value = 6; +inline bool UninterpretedOption::_internal_has_double_value() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool UninterpretedOption::has_double_value() const { + return _internal_has_double_value(); +} +inline void UninterpretedOption::clear_double_value() { + double_value_ = 0; + _has_bits_[0] &= ~0x00000020u; +} +inline double UninterpretedOption::_internal_double_value() const { + return double_value_; +} +inline double UninterpretedOption::double_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.double_value) + return _internal_double_value(); +} +inline void UninterpretedOption::_internal_set_double_value(double value) { + _has_bits_[0] |= 0x00000020u; + double_value_ = value; +} +inline void UninterpretedOption::set_double_value(double value) { + _internal_set_double_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.double_value) +} + +// optional bytes string_value = 7; +inline bool UninterpretedOption::_internal_has_string_value() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool UninterpretedOption::has_string_value() const { + return _internal_has_string_value(); +} +inline void UninterpretedOption::clear_string_value() { + string_value_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& UninterpretedOption::string_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.string_value) + return _internal_string_value(); +} +inline void UninterpretedOption::set_string_value(const std::string& value) { + _internal_set_string_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.string_value) +} +inline std::string* UninterpretedOption::mutable_string_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.UninterpretedOption.string_value) + return _internal_mutable_string_value(); +} +inline const std::string& UninterpretedOption::_internal_string_value() const { + return string_value_.Get(); +} +inline void UninterpretedOption::_internal_set_string_value(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + string_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void UninterpretedOption::set_string_value(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + string_value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.UninterpretedOption.string_value) +} +inline void UninterpretedOption::set_string_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + string_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.UninterpretedOption.string_value) +} +inline void UninterpretedOption::set_string_value(const void* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + string_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.UninterpretedOption.string_value) +} +inline std::string* UninterpretedOption::_internal_mutable_string_value() { + _has_bits_[0] |= 0x00000002u; + return string_value_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* UninterpretedOption::release_string_value() { + // @@protoc_insertion_point(field_release:google.protobuf.UninterpretedOption.string_value) + if (!_internal_has_string_value()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return string_value_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void UninterpretedOption::set_allocated_string_value(std::string* string_value) { + if (string_value != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + string_value_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), string_value, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.UninterpretedOption.string_value) +} + +// optional string aggregate_value = 8; +inline bool UninterpretedOption::_internal_has_aggregate_value() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool UninterpretedOption::has_aggregate_value() const { + return _internal_has_aggregate_value(); +} +inline void UninterpretedOption::clear_aggregate_value() { + aggregate_value_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& UninterpretedOption::aggregate_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.UninterpretedOption.aggregate_value) + return _internal_aggregate_value(); +} +inline void UninterpretedOption::set_aggregate_value(const std::string& value) { + _internal_set_aggregate_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.UninterpretedOption.aggregate_value) +} +inline std::string* UninterpretedOption::mutable_aggregate_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.UninterpretedOption.aggregate_value) + return _internal_mutable_aggregate_value(); +} +inline const std::string& UninterpretedOption::_internal_aggregate_value() const { + return aggregate_value_.Get(); +} +inline void UninterpretedOption::_internal_set_aggregate_value(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + aggregate_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void UninterpretedOption::set_aggregate_value(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + aggregate_value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.UninterpretedOption.aggregate_value) +} +inline void UninterpretedOption::set_aggregate_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + aggregate_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.UninterpretedOption.aggregate_value) +} +inline void UninterpretedOption::set_aggregate_value(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + aggregate_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.UninterpretedOption.aggregate_value) +} +inline std::string* UninterpretedOption::_internal_mutable_aggregate_value() { + _has_bits_[0] |= 0x00000004u; + return aggregate_value_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* UninterpretedOption::release_aggregate_value() { + // @@protoc_insertion_point(field_release:google.protobuf.UninterpretedOption.aggregate_value) + if (!_internal_has_aggregate_value()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return aggregate_value_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void UninterpretedOption::set_allocated_aggregate_value(std::string* aggregate_value) { + if (aggregate_value != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + aggregate_value_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), aggregate_value, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.UninterpretedOption.aggregate_value) +} + +// ------------------------------------------------------------------- + +// SourceCodeInfo_Location + +// repeated int32 path = 1 [packed = true]; +inline int SourceCodeInfo_Location::_internal_path_size() const { + return path_.size(); +} +inline int SourceCodeInfo_Location::path_size() const { + return _internal_path_size(); +} +inline void SourceCodeInfo_Location::clear_path() { + path_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 SourceCodeInfo_Location::_internal_path(int index) const { + return path_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 SourceCodeInfo_Location::path(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.SourceCodeInfo.Location.path) + return _internal_path(index); +} +inline void SourceCodeInfo_Location::set_path(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + path_.Set(index, value); + // @@protoc_insertion_point(field_set:google.protobuf.SourceCodeInfo.Location.path) +} +inline void SourceCodeInfo_Location::_internal_add_path(::PROTOBUF_NAMESPACE_ID::int32 value) { + path_.Add(value); +} +inline void SourceCodeInfo_Location::add_path(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_path(value); + // @@protoc_insertion_point(field_add:google.protobuf.SourceCodeInfo.Location.path) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +SourceCodeInfo_Location::_internal_path() const { + return path_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +SourceCodeInfo_Location::path() const { + // @@protoc_insertion_point(field_list:google.protobuf.SourceCodeInfo.Location.path) + return _internal_path(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +SourceCodeInfo_Location::_internal_mutable_path() { + return &path_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +SourceCodeInfo_Location::mutable_path() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.SourceCodeInfo.Location.path) + return _internal_mutable_path(); +} + +// repeated int32 span = 2 [packed = true]; +inline int SourceCodeInfo_Location::_internal_span_size() const { + return span_.size(); +} +inline int SourceCodeInfo_Location::span_size() const { + return _internal_span_size(); +} +inline void SourceCodeInfo_Location::clear_span() { + span_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 SourceCodeInfo_Location::_internal_span(int index) const { + return span_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 SourceCodeInfo_Location::span(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.SourceCodeInfo.Location.span) + return _internal_span(index); +} +inline void SourceCodeInfo_Location::set_span(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + span_.Set(index, value); + // @@protoc_insertion_point(field_set:google.protobuf.SourceCodeInfo.Location.span) +} +inline void SourceCodeInfo_Location::_internal_add_span(::PROTOBUF_NAMESPACE_ID::int32 value) { + span_.Add(value); +} +inline void SourceCodeInfo_Location::add_span(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_span(value); + // @@protoc_insertion_point(field_add:google.protobuf.SourceCodeInfo.Location.span) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +SourceCodeInfo_Location::_internal_span() const { + return span_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +SourceCodeInfo_Location::span() const { + // @@protoc_insertion_point(field_list:google.protobuf.SourceCodeInfo.Location.span) + return _internal_span(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +SourceCodeInfo_Location::_internal_mutable_span() { + return &span_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +SourceCodeInfo_Location::mutable_span() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.SourceCodeInfo.Location.span) + return _internal_mutable_span(); +} + +// optional string leading_comments = 3; +inline bool SourceCodeInfo_Location::_internal_has_leading_comments() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool SourceCodeInfo_Location::has_leading_comments() const { + return _internal_has_leading_comments(); +} +inline void SourceCodeInfo_Location::clear_leading_comments() { + leading_comments_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& SourceCodeInfo_Location::leading_comments() const { + // @@protoc_insertion_point(field_get:google.protobuf.SourceCodeInfo.Location.leading_comments) + return _internal_leading_comments(); +} +inline void SourceCodeInfo_Location::set_leading_comments(const std::string& value) { + _internal_set_leading_comments(value); + // @@protoc_insertion_point(field_set:google.protobuf.SourceCodeInfo.Location.leading_comments) +} +inline std::string* SourceCodeInfo_Location::mutable_leading_comments() { + // @@protoc_insertion_point(field_mutable:google.protobuf.SourceCodeInfo.Location.leading_comments) + return _internal_mutable_leading_comments(); +} +inline const std::string& SourceCodeInfo_Location::_internal_leading_comments() const { + return leading_comments_.Get(); +} +inline void SourceCodeInfo_Location::_internal_set_leading_comments(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + leading_comments_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SourceCodeInfo_Location::set_leading_comments(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + leading_comments_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.SourceCodeInfo.Location.leading_comments) +} +inline void SourceCodeInfo_Location::set_leading_comments(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + leading_comments_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.SourceCodeInfo.Location.leading_comments) +} +inline void SourceCodeInfo_Location::set_leading_comments(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + leading_comments_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.SourceCodeInfo.Location.leading_comments) +} +inline std::string* SourceCodeInfo_Location::_internal_mutable_leading_comments() { + _has_bits_[0] |= 0x00000001u; + return leading_comments_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SourceCodeInfo_Location::release_leading_comments() { + // @@protoc_insertion_point(field_release:google.protobuf.SourceCodeInfo.Location.leading_comments) + if (!_internal_has_leading_comments()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return leading_comments_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SourceCodeInfo_Location::set_allocated_leading_comments(std::string* leading_comments) { + if (leading_comments != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + leading_comments_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), leading_comments, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.SourceCodeInfo.Location.leading_comments) +} + +// optional string trailing_comments = 4; +inline bool SourceCodeInfo_Location::_internal_has_trailing_comments() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool SourceCodeInfo_Location::has_trailing_comments() const { + return _internal_has_trailing_comments(); +} +inline void SourceCodeInfo_Location::clear_trailing_comments() { + trailing_comments_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& SourceCodeInfo_Location::trailing_comments() const { + // @@protoc_insertion_point(field_get:google.protobuf.SourceCodeInfo.Location.trailing_comments) + return _internal_trailing_comments(); +} +inline void SourceCodeInfo_Location::set_trailing_comments(const std::string& value) { + _internal_set_trailing_comments(value); + // @@protoc_insertion_point(field_set:google.protobuf.SourceCodeInfo.Location.trailing_comments) +} +inline std::string* SourceCodeInfo_Location::mutable_trailing_comments() { + // @@protoc_insertion_point(field_mutable:google.protobuf.SourceCodeInfo.Location.trailing_comments) + return _internal_mutable_trailing_comments(); +} +inline const std::string& SourceCodeInfo_Location::_internal_trailing_comments() const { + return trailing_comments_.Get(); +} +inline void SourceCodeInfo_Location::_internal_set_trailing_comments(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + trailing_comments_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SourceCodeInfo_Location::set_trailing_comments(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + trailing_comments_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.SourceCodeInfo.Location.trailing_comments) +} +inline void SourceCodeInfo_Location::set_trailing_comments(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + trailing_comments_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.SourceCodeInfo.Location.trailing_comments) +} +inline void SourceCodeInfo_Location::set_trailing_comments(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + trailing_comments_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.SourceCodeInfo.Location.trailing_comments) +} +inline std::string* SourceCodeInfo_Location::_internal_mutable_trailing_comments() { + _has_bits_[0] |= 0x00000002u; + return trailing_comments_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SourceCodeInfo_Location::release_trailing_comments() { + // @@protoc_insertion_point(field_release:google.protobuf.SourceCodeInfo.Location.trailing_comments) + if (!_internal_has_trailing_comments()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return trailing_comments_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SourceCodeInfo_Location::set_allocated_trailing_comments(std::string* trailing_comments) { + if (trailing_comments != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + trailing_comments_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), trailing_comments, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.SourceCodeInfo.Location.trailing_comments) +} + +// repeated string leading_detached_comments = 6; +inline int SourceCodeInfo_Location::_internal_leading_detached_comments_size() const { + return leading_detached_comments_.size(); +} +inline int SourceCodeInfo_Location::leading_detached_comments_size() const { + return _internal_leading_detached_comments_size(); +} +inline void SourceCodeInfo_Location::clear_leading_detached_comments() { + leading_detached_comments_.Clear(); +} +inline std::string* SourceCodeInfo_Location::add_leading_detached_comments() { + // @@protoc_insertion_point(field_add_mutable:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) + return _internal_add_leading_detached_comments(); +} +inline const std::string& SourceCodeInfo_Location::_internal_leading_detached_comments(int index) const { + return leading_detached_comments_.Get(index); +} +inline const std::string& SourceCodeInfo_Location::leading_detached_comments(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) + return _internal_leading_detached_comments(index); +} +inline std::string* SourceCodeInfo_Location::mutable_leading_detached_comments(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) + return leading_detached_comments_.Mutable(index); +} +inline void SourceCodeInfo_Location::set_leading_detached_comments(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) + leading_detached_comments_.Mutable(index)->assign(value); +} +inline void SourceCodeInfo_Location::set_leading_detached_comments(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) + leading_detached_comments_.Mutable(index)->assign(std::move(value)); +} +inline void SourceCodeInfo_Location::set_leading_detached_comments(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + leading_detached_comments_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) +} +inline void SourceCodeInfo_Location::set_leading_detached_comments(int index, const char* value, size_t size) { + leading_detached_comments_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) +} +inline std::string* SourceCodeInfo_Location::_internal_add_leading_detached_comments() { + return leading_detached_comments_.Add(); +} +inline void SourceCodeInfo_Location::add_leading_detached_comments(const std::string& value) { + leading_detached_comments_.Add()->assign(value); + // @@protoc_insertion_point(field_add:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) +} +inline void SourceCodeInfo_Location::add_leading_detached_comments(std::string&& value) { + leading_detached_comments_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) +} +inline void SourceCodeInfo_Location::add_leading_detached_comments(const char* value) { + GOOGLE_DCHECK(value != nullptr); + leading_detached_comments_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) +} +inline void SourceCodeInfo_Location::add_leading_detached_comments(const char* value, size_t size) { + leading_detached_comments_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +SourceCodeInfo_Location::leading_detached_comments() const { + // @@protoc_insertion_point(field_list:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) + return leading_detached_comments_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +SourceCodeInfo_Location::mutable_leading_detached_comments() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.SourceCodeInfo.Location.leading_detached_comments) + return &leading_detached_comments_; +} + +// ------------------------------------------------------------------- + +// SourceCodeInfo + +// repeated .google.protobuf.SourceCodeInfo.Location location = 1; +inline int SourceCodeInfo::_internal_location_size() const { + return location_.size(); +} +inline int SourceCodeInfo::location_size() const { + return _internal_location_size(); +} +inline void SourceCodeInfo::clear_location() { + location_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location* SourceCodeInfo::mutable_location(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.SourceCodeInfo.location) + return location_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location >* +SourceCodeInfo::mutable_location() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.SourceCodeInfo.location) + return &location_; +} +inline const PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location& SourceCodeInfo::_internal_location(int index) const { + return location_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location& SourceCodeInfo::location(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.SourceCodeInfo.location) + return _internal_location(index); +} +inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location* SourceCodeInfo::_internal_add_location() { + return location_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location* SourceCodeInfo::add_location() { + // @@protoc_insertion_point(field_add:google.protobuf.SourceCodeInfo.location) + return _internal_add_location(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::SourceCodeInfo_Location >& +SourceCodeInfo::location() const { + // @@protoc_insertion_point(field_list:google.protobuf.SourceCodeInfo.location) + return location_; +} + +// ------------------------------------------------------------------- + +// GeneratedCodeInfo_Annotation + +// repeated int32 path = 1 [packed = true]; +inline int GeneratedCodeInfo_Annotation::_internal_path_size() const { + return path_.size(); +} +inline int GeneratedCodeInfo_Annotation::path_size() const { + return _internal_path_size(); +} +inline void GeneratedCodeInfo_Annotation::clear_path() { + path_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 GeneratedCodeInfo_Annotation::_internal_path(int index) const { + return path_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 GeneratedCodeInfo_Annotation::path(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.GeneratedCodeInfo.Annotation.path) + return _internal_path(index); +} +inline void GeneratedCodeInfo_Annotation::set_path(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + path_.Set(index, value); + // @@protoc_insertion_point(field_set:google.protobuf.GeneratedCodeInfo.Annotation.path) +} +inline void GeneratedCodeInfo_Annotation::_internal_add_path(::PROTOBUF_NAMESPACE_ID::int32 value) { + path_.Add(value); +} +inline void GeneratedCodeInfo_Annotation::add_path(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_path(value); + // @@protoc_insertion_point(field_add:google.protobuf.GeneratedCodeInfo.Annotation.path) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +GeneratedCodeInfo_Annotation::_internal_path() const { + return path_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +GeneratedCodeInfo_Annotation::path() const { + // @@protoc_insertion_point(field_list:google.protobuf.GeneratedCodeInfo.Annotation.path) + return _internal_path(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +GeneratedCodeInfo_Annotation::_internal_mutable_path() { + return &path_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +GeneratedCodeInfo_Annotation::mutable_path() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.GeneratedCodeInfo.Annotation.path) + return _internal_mutable_path(); +} + +// optional string source_file = 2; +inline bool GeneratedCodeInfo_Annotation::_internal_has_source_file() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool GeneratedCodeInfo_Annotation::has_source_file() const { + return _internal_has_source_file(); +} +inline void GeneratedCodeInfo_Annotation::clear_source_file() { + source_file_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& GeneratedCodeInfo_Annotation::source_file() const { + // @@protoc_insertion_point(field_get:google.protobuf.GeneratedCodeInfo.Annotation.source_file) + return _internal_source_file(); +} +inline void GeneratedCodeInfo_Annotation::set_source_file(const std::string& value) { + _internal_set_source_file(value); + // @@protoc_insertion_point(field_set:google.protobuf.GeneratedCodeInfo.Annotation.source_file) +} +inline std::string* GeneratedCodeInfo_Annotation::mutable_source_file() { + // @@protoc_insertion_point(field_mutable:google.protobuf.GeneratedCodeInfo.Annotation.source_file) + return _internal_mutable_source_file(); +} +inline const std::string& GeneratedCodeInfo_Annotation::_internal_source_file() const { + return source_file_.Get(); +} +inline void GeneratedCodeInfo_Annotation::_internal_set_source_file(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + source_file_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void GeneratedCodeInfo_Annotation::set_source_file(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + source_file_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.GeneratedCodeInfo.Annotation.source_file) +} +inline void GeneratedCodeInfo_Annotation::set_source_file(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + source_file_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.GeneratedCodeInfo.Annotation.source_file) +} +inline void GeneratedCodeInfo_Annotation::set_source_file(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + source_file_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.GeneratedCodeInfo.Annotation.source_file) +} +inline std::string* GeneratedCodeInfo_Annotation::_internal_mutable_source_file() { + _has_bits_[0] |= 0x00000001u; + return source_file_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* GeneratedCodeInfo_Annotation::release_source_file() { + // @@protoc_insertion_point(field_release:google.protobuf.GeneratedCodeInfo.Annotation.source_file) + if (!_internal_has_source_file()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return source_file_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void GeneratedCodeInfo_Annotation::set_allocated_source_file(std::string* source_file) { + if (source_file != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + source_file_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), source_file, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.GeneratedCodeInfo.Annotation.source_file) +} + +// optional int32 begin = 3; +inline bool GeneratedCodeInfo_Annotation::_internal_has_begin() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool GeneratedCodeInfo_Annotation::has_begin() const { + return _internal_has_begin(); +} +inline void GeneratedCodeInfo_Annotation::clear_begin() { + begin_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 GeneratedCodeInfo_Annotation::_internal_begin() const { + return begin_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 GeneratedCodeInfo_Annotation::begin() const { + // @@protoc_insertion_point(field_get:google.protobuf.GeneratedCodeInfo.Annotation.begin) + return _internal_begin(); +} +inline void GeneratedCodeInfo_Annotation::_internal_set_begin(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + begin_ = value; +} +inline void GeneratedCodeInfo_Annotation::set_begin(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_begin(value); + // @@protoc_insertion_point(field_set:google.protobuf.GeneratedCodeInfo.Annotation.begin) +} + +// optional int32 end = 4; +inline bool GeneratedCodeInfo_Annotation::_internal_has_end() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool GeneratedCodeInfo_Annotation::has_end() const { + return _internal_has_end(); +} +inline void GeneratedCodeInfo_Annotation::clear_end() { + end_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 GeneratedCodeInfo_Annotation::_internal_end() const { + return end_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 GeneratedCodeInfo_Annotation::end() const { + // @@protoc_insertion_point(field_get:google.protobuf.GeneratedCodeInfo.Annotation.end) + return _internal_end(); +} +inline void GeneratedCodeInfo_Annotation::_internal_set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + end_ = value; +} +inline void GeneratedCodeInfo_Annotation::set_end(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_end(value); + // @@protoc_insertion_point(field_set:google.protobuf.GeneratedCodeInfo.Annotation.end) +} + +// ------------------------------------------------------------------- + +// GeneratedCodeInfo + +// repeated .google.protobuf.GeneratedCodeInfo.Annotation annotation = 1; +inline int GeneratedCodeInfo::_internal_annotation_size() const { + return annotation_.size(); +} +inline int GeneratedCodeInfo::annotation_size() const { + return _internal_annotation_size(); +} +inline void GeneratedCodeInfo::clear_annotation() { + annotation_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation* GeneratedCodeInfo::mutable_annotation(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.GeneratedCodeInfo.annotation) + return annotation_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation >* +GeneratedCodeInfo::mutable_annotation() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.GeneratedCodeInfo.annotation) + return &annotation_; +} +inline const PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation& GeneratedCodeInfo::_internal_annotation(int index) const { + return annotation_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation& GeneratedCodeInfo::annotation(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.GeneratedCodeInfo.annotation) + return _internal_annotation(index); +} +inline PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation* GeneratedCodeInfo::_internal_add_annotation() { + return annotation_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation* GeneratedCodeInfo::add_annotation() { + // @@protoc_insertion_point(field_add:google.protobuf.GeneratedCodeInfo.annotation) + return _internal_add_annotation(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::GeneratedCodeInfo_Annotation >& +GeneratedCodeInfo::annotation() const { + // @@protoc_insertion_point(field_list:google.protobuf.GeneratedCodeInfo.annotation) + return annotation_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type>() { + return PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Type_descriptor(); +} +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label>() { + return PROTOBUF_NAMESPACE_ID::FieldDescriptorProto_Label_descriptor(); +} +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode>() { + return PROTOBUF_NAMESPACE_ID::FileOptions_OptimizeMode_descriptor(); +} +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::FieldOptions_CType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::FieldOptions_CType>() { + return PROTOBUF_NAMESPACE_ID::FieldOptions_CType_descriptor(); +} +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::FieldOptions_JSType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::FieldOptions_JSType>() { + return PROTOBUF_NAMESPACE_ID::FieldOptions_JSType_descriptor(); +} +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel>() { + return PROTOBUF_NAMESPACE_ID::MethodOptions_IdempotencyLevel_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fdescriptor_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor_database.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor_database.h new file mode 100644 index 0000000000000000000000000000000000000000..30ea31f99e339a9423842220053394ce105e65d0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/descriptor_database.h @@ -0,0 +1,399 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Interface for manipulating databases of descriptors. + +#ifndef GOOGLE_PROTOBUF_DESCRIPTOR_DATABASE_H__ +#define GOOGLE_PROTOBUF_DESCRIPTOR_DATABASE_H__ + +#include +#include +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +// Defined in this file. +class DescriptorDatabase; +class SimpleDescriptorDatabase; +class EncodedDescriptorDatabase; +class DescriptorPoolDatabase; +class MergedDescriptorDatabase; + +// Abstract interface for a database of descriptors. +// +// This is useful if you want to create a DescriptorPool which loads +// descriptors on-demand from some sort of large database. If the database +// is large, it may be inefficient to enumerate every .proto file inside it +// calling DescriptorPool::BuildFile() for each one. Instead, a DescriptorPool +// can be created which wraps a DescriptorDatabase and only builds particular +// descriptors when they are needed. +class PROTOBUF_EXPORT DescriptorDatabase { + public: + inline DescriptorDatabase() {} + virtual ~DescriptorDatabase(); + + // Find a file by file name. Fills in in *output and returns true if found. + // Otherwise, returns false, leaving the contents of *output undefined. + virtual bool FindFileByName(const std::string& filename, + FileDescriptorProto* output) = 0; + + // Find the file that declares the given fully-qualified symbol name. + // If found, fills in *output and returns true, otherwise returns false + // and leaves *output undefined. + virtual bool FindFileContainingSymbol(const std::string& symbol_name, + FileDescriptorProto* output) = 0; + + // Find the file which defines an extension extending the given message type + // with the given field number. If found, fills in *output and returns true, + // otherwise returns false and leaves *output undefined. containing_type + // must be a fully-qualified type name. + virtual bool FindFileContainingExtension(const std::string& containing_type, + int field_number, + FileDescriptorProto* output) = 0; + + // Finds the tag numbers used by all known extensions of + // extendee_type, and appends them to output in an undefined + // order. This method is best-effort: it's not guaranteed that the + // database will find all extensions, and it's not guaranteed that + // FindFileContainingExtension will return true on all of the found + // numbers. Returns true if the search was successful, otherwise + // returns false and leaves output unchanged. + // + // This method has a default implementation that always returns + // false. + virtual bool FindAllExtensionNumbers(const std::string& /* extendee_type */, + std::vector* /* output */) { + return false; + } + + + // Finds the file names and appends them to the output in an + // undefined order. This method is best-effort: it's not guaranteed that the + // database will find all files. Returns true if the database supports + // searching all file names, otherwise returns false and leaves output + // unchanged. + // + // This method has a default implementation that always returns + // false. + virtual bool FindAllFileNames(std::vector* /*output*/) { + return false; + } + + // Finds the package names and appends them to the output in an + // undefined order. This method is best-effort: it's not guaranteed that the + // database will find all packages. Returns true if the database supports + // searching all package names, otherwise returns false and leaves output + // unchanged. + bool FindAllPackageNames(std::vector* output); + + // Finds the message names and appends them to the output in an + // undefined order. This method is best-effort: it's not guaranteed that the + // database will find all messages. Returns true if the database supports + // searching all message names, otherwise returns false and leaves output + // unchanged. + bool FindAllMessageNames(std::vector* output); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DescriptorDatabase); +}; + +// A DescriptorDatabase into which you can insert files manually. +// +// FindFileContainingSymbol() is fully-implemented. When you add a file, its +// symbols will be indexed for this purpose. Note that the implementation +// may return false positives, but only if it isn't possible for the symbol +// to be defined in any other file. In particular, if a file defines a symbol +// "Foo", then searching for "Foo.[anything]" will match that file. This way, +// the database does not need to aggressively index all children of a symbol. +// +// FindFileContainingExtension() is mostly-implemented. It works if and only +// if the original FieldDescriptorProto defining the extension has a +// fully-qualified type name in its "extendee" field (i.e. starts with a '.'). +// If the extendee is a relative name, SimpleDescriptorDatabase will not +// attempt to resolve the type, so it will not know what type the extension is +// extending. Therefore, calling FindFileContainingExtension() with the +// extension's containing type will never actually find that extension. Note +// that this is an unlikely problem, as all FileDescriptorProtos created by the +// protocol compiler (as well as ones created by calling +// FileDescriptor::CopyTo()) will always use fully-qualified names for all +// types. You only need to worry if you are constructing FileDescriptorProtos +// yourself, or are calling compiler::Parser directly. +class PROTOBUF_EXPORT SimpleDescriptorDatabase : public DescriptorDatabase { + public: + SimpleDescriptorDatabase(); + ~SimpleDescriptorDatabase() override; + + // Adds the FileDescriptorProto to the database, making a copy. The object + // can be deleted after Add() returns. Returns false if the file conflicted + // with a file already in the database, in which case an error will have + // been written to GOOGLE_LOG(ERROR). + bool Add(const FileDescriptorProto& file); + + // Adds the FileDescriptorProto to the database and takes ownership of it. + bool AddAndOwn(const FileDescriptorProto* file); + + // implements DescriptorDatabase ----------------------------------- + bool FindFileByName(const std::string& filename, + FileDescriptorProto* output) override; + bool FindFileContainingSymbol(const std::string& symbol_name, + FileDescriptorProto* output) override; + bool FindFileContainingExtension(const std::string& containing_type, + int field_number, + FileDescriptorProto* output) override; + bool FindAllExtensionNumbers(const std::string& extendee_type, + std::vector* output) override; + + bool FindAllFileNames(std::vector* output) override; + + private: + // So that it can use DescriptorIndex. + friend class EncodedDescriptorDatabase; + + // An index mapping file names, symbol names, and extension numbers to + // some sort of values. + template + class DescriptorIndex { + public: + // Helpers to recursively add particular descriptors and all their contents + // to the index. + bool AddFile(const FileDescriptorProto& file, Value value); + bool AddSymbol(const std::string& name, Value value); + bool AddNestedExtensions(const std::string& filename, + const DescriptorProto& message_type, Value value); + bool AddExtension(const std::string& filename, + const FieldDescriptorProto& field, Value value); + + Value FindFile(const std::string& filename); + Value FindSymbol(const std::string& name); + Value FindExtension(const std::string& containing_type, int field_number); + bool FindAllExtensionNumbers(const std::string& containing_type, + std::vector* output); + void FindAllFileNames(std::vector* output); + + private: + std::map by_name_; + std::map by_symbol_; + std::map, Value> by_extension_; + + // Invariant: The by_symbol_ map does not contain any symbols which are + // prefixes of other symbols in the map. For example, "foo.bar" is a + // prefix of "foo.bar.baz" (but is not a prefix of "foo.barbaz"). + // + // This invariant is important because it means that given a symbol name, + // we can find a key in the map which is a prefix of the symbol in O(lg n) + // time, and we know that there is at most one such key. + // + // The prefix lookup algorithm works like so: + // 1) Find the last key in the map which is less than or equal to the + // search key. + // 2) If the found key is a prefix of the search key, then return it. + // Otherwise, there is no match. + // + // I am sure this algorithm has been described elsewhere, but since I + // wasn't able to find it quickly I will instead prove that it works + // myself. The key to the algorithm is that if a match exists, step (1) + // will find it. Proof: + // 1) Define the "search key" to be the key we are looking for, the "found + // key" to be the key found in step (1), and the "match key" to be the + // key which actually matches the search key (i.e. the key we're trying + // to find). + // 2) The found key must be less than or equal to the search key by + // definition. + // 3) The match key must also be less than or equal to the search key + // (because it is a prefix). + // 4) The match key cannot be greater than the found key, because if it + // were, then step (1) of the algorithm would have returned the match + // key instead (since it finds the *greatest* key which is less than or + // equal to the search key). + // 5) Therefore, the found key must be between the match key and the search + // key, inclusive. + // 6) Since the search key must be a sub-symbol of the match key, if it is + // not equal to the match key, then search_key[match_key.size()] must + // be '.'. + // 7) Since '.' sorts before any other character that is valid in a symbol + // name, then if the found key is not equal to the match key, then + // found_key[match_key.size()] must also be '.', because any other value + // would make it sort after the search key. + // 8) Therefore, if the found key is not equal to the match key, then the + // found key must be a sub-symbol of the match key. However, this would + // contradict our map invariant which says that no symbol in the map is + // a sub-symbol of any other. + // 9) Therefore, the found key must match the match key. + // + // The above proof assumes the match key exists. In the case that the + // match key does not exist, then step (1) will return some other symbol. + // That symbol cannot be a super-symbol of the search key since if it were, + // then it would be a match, and we're assuming the match key doesn't exist. + // Therefore, step 2 will correctly return no match. + }; + + DescriptorIndex index_; + std::vector> files_to_delete_; + + // If file is non-NULL, copy it into *output and return true, otherwise + // return false. + bool MaybeCopy(const FileDescriptorProto* file, FileDescriptorProto* output); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(SimpleDescriptorDatabase); +}; + +// Very similar to SimpleDescriptorDatabase, but stores all the descriptors +// as raw bytes and generally tries to use as little memory as possible. +// +// The same caveats regarding FindFileContainingExtension() apply as with +// SimpleDescriptorDatabase. +class PROTOBUF_EXPORT EncodedDescriptorDatabase : public DescriptorDatabase { + public: + EncodedDescriptorDatabase(); + ~EncodedDescriptorDatabase() override; + + // Adds the FileDescriptorProto to the database. The descriptor is provided + // in encoded form. The database does not make a copy of the bytes, nor + // does it take ownership; it's up to the caller to make sure the bytes + // remain valid for the life of the database. Returns false and logs an error + // if the bytes are not a valid FileDescriptorProto or if the file conflicted + // with a file already in the database. + bool Add(const void* encoded_file_descriptor, int size); + + // Like Add(), but makes a copy of the data, so that the caller does not + // need to keep it around. + bool AddCopy(const void* encoded_file_descriptor, int size); + + // Like FindFileContainingSymbol but returns only the name of the file. + bool FindNameOfFileContainingSymbol(const std::string& symbol_name, + std::string* output); + + // implements DescriptorDatabase ----------------------------------- + bool FindFileByName(const std::string& filename, + FileDescriptorProto* output) override; + bool FindFileContainingSymbol(const std::string& symbol_name, + FileDescriptorProto* output) override; + bool FindFileContainingExtension(const std::string& containing_type, + int field_number, + FileDescriptorProto* output) override; + bool FindAllExtensionNumbers(const std::string& extendee_type, + std::vector* output) override; + bool FindAllFileNames(std::vector* output) override; + + private: + class DescriptorIndex; + // Keep DescriptorIndex by pointer to hide the implementation to keep a + // cleaner header. + std::unique_ptr index_; + std::vector files_to_delete_; + + // If encoded_file.first is non-NULL, parse the data into *output and return + // true, otherwise return false. + bool MaybeParse(std::pair encoded_file, + FileDescriptorProto* output); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(EncodedDescriptorDatabase); +}; + +// A DescriptorDatabase that fetches files from a given pool. +class PROTOBUF_EXPORT DescriptorPoolDatabase : public DescriptorDatabase { + public: + explicit DescriptorPoolDatabase(const DescriptorPool& pool); + ~DescriptorPoolDatabase() override; + + // implements DescriptorDatabase ----------------------------------- + bool FindFileByName(const std::string& filename, + FileDescriptorProto* output) override; + bool FindFileContainingSymbol(const std::string& symbol_name, + FileDescriptorProto* output) override; + bool FindFileContainingExtension(const std::string& containing_type, + int field_number, + FileDescriptorProto* output) override; + bool FindAllExtensionNumbers(const std::string& extendee_type, + std::vector* output) override; + + private: + const DescriptorPool& pool_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DescriptorPoolDatabase); +}; + +// A DescriptorDatabase that wraps two or more others. It first searches the +// first database and, if that fails, tries the second, and so on. +class PROTOBUF_EXPORT MergedDescriptorDatabase : public DescriptorDatabase { + public: + // Merge just two databases. The sources remain property of the caller. + MergedDescriptorDatabase(DescriptorDatabase* source1, + DescriptorDatabase* source2); + // Merge more than two databases. The sources remain property of the caller. + // The vector may be deleted after the constructor returns but the + // DescriptorDatabases need to stick around. + explicit MergedDescriptorDatabase( + const std::vector& sources); + ~MergedDescriptorDatabase() override; + + // implements DescriptorDatabase ----------------------------------- + bool FindFileByName(const std::string& filename, + FileDescriptorProto* output) override; + bool FindFileContainingSymbol(const std::string& symbol_name, + FileDescriptorProto* output) override; + bool FindFileContainingExtension(const std::string& containing_type, + int field_number, + FileDescriptorProto* output) override; + // Merges the results of calling all databases. Returns true iff any + // of the databases returned true. + bool FindAllExtensionNumbers(const std::string& extendee_type, + std::vector* output) override; + + + private: + std::vector sources_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MergedDescriptorDatabase); +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_DESCRIPTOR_DATABASE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/duration.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/duration.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..769fd2ba6f068eb8d94e18fdbc7f97452d9a9e06 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/duration.pb.h @@ -0,0 +1,282 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/duration.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fduration_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fduration_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fduration_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2fduration_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fduration_2eproto; +PROTOBUF_NAMESPACE_OPEN +class Duration; +class DurationDefaultTypeInternal; +PROTOBUF_EXPORT extern DurationDefaultTypeInternal _Duration_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Duration* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +// =================================================================== + +class PROTOBUF_EXPORT Duration PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Duration) */ { + public: + inline Duration() : Duration(nullptr) {} + virtual ~Duration(); + + Duration(const Duration& from); + Duration(Duration&& from) noexcept + : Duration() { + *this = ::std::move(from); + } + + inline Duration& operator=(const Duration& from) { + CopyFrom(from); + return *this; + } + inline Duration& operator=(Duration&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Duration& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Duration* internal_default_instance() { + return reinterpret_cast( + &_Duration_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Duration& a, Duration& b) { + a.Swap(&b); + } + inline void Swap(Duration* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Duration* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Duration* New() const final { + return CreateMaybeMessage(nullptr); + } + + Duration* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Duration& from); + void MergeFrom(const Duration& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Duration* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Duration"; + } + protected: + explicit Duration(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fduration_2eproto); + return ::descriptor_table_google_2fprotobuf_2fduration_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kSecondsFieldNumber = 1, + kNanosFieldNumber = 2, + }; + // int64 seconds = 1; + void clear_seconds(); + ::PROTOBUF_NAMESPACE_ID::int64 seconds() const; + void set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_seconds() const; + void _internal_set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // int32 nanos = 2; + void clear_nanos(); + ::PROTOBUF_NAMESPACE_ID::int32 nanos() const; + void set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_nanos() const; + void _internal_set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Duration) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::int64 seconds_; + ::PROTOBUF_NAMESPACE_ID::int32 nanos_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fduration_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Duration + +// int64 seconds = 1; +inline void Duration::clear_seconds() { + seconds_ = PROTOBUF_LONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Duration::_internal_seconds() const { + return seconds_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Duration::seconds() const { + // @@protoc_insertion_point(field_get:google.protobuf.Duration.seconds) + return _internal_seconds(); +} +inline void Duration::_internal_set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value) { + + seconds_ = value; +} +inline void Duration::set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_seconds(value); + // @@protoc_insertion_point(field_set:google.protobuf.Duration.seconds) +} + +// int32 nanos = 2; +inline void Duration::clear_nanos() { + nanos_ = 0; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Duration::_internal_nanos() const { + return nanos_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Duration::nanos() const { + // @@protoc_insertion_point(field_get:google.protobuf.Duration.nanos) + return _internal_nanos(); +} +inline void Duration::_internal_set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value) { + + nanos_ = value; +} +inline void Duration::set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_nanos(value); + // @@protoc_insertion_point(field_set:google.protobuf.Duration.nanos) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fduration_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/dynamic_message.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/dynamic_message.h new file mode 100644 index 0000000000000000000000000000000000000000..e9f02a2eb43f7c860890fc67540392539ecdfae4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/dynamic_message.h @@ -0,0 +1,244 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Defines an implementation of Message which can emulate types which are not +// known at compile-time. + +#ifndef GOOGLE_PROTOBUF_DYNAMIC_MESSAGE_H__ +#define GOOGLE_PROTOBUF_DYNAMIC_MESSAGE_H__ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +namespace google { +namespace protobuf { + +// Defined in other files. +class Descriptor; // descriptor.h +class DescriptorPool; // descriptor.h + +// Constructs implementations of Message which can emulate types which are not +// known at compile-time. +// +// Sometimes you want to be able to manipulate protocol types that you don't +// know about at compile time. It would be nice to be able to construct +// a Message object which implements the message type given by any arbitrary +// Descriptor. DynamicMessage provides this. +// +// As it turns out, a DynamicMessage needs to construct extra +// information about its type in order to operate. Most of this information +// can be shared between all DynamicMessages of the same type. But, caching +// this information in some sort of global map would be a bad idea, since +// the cached information for a particular descriptor could outlive the +// descriptor itself. To avoid this problem, DynamicMessageFactory +// encapsulates this "cache". All DynamicMessages of the same type created +// from the same factory will share the same support data. Any Descriptors +// used with a particular factory must outlive the factory. +class PROTOBUF_EXPORT DynamicMessageFactory : public MessageFactory { + public: + // Construct a DynamicMessageFactory that will search for extensions in + // the DescriptorPool in which the extendee is defined. + DynamicMessageFactory(); + + // Construct a DynamicMessageFactory that will search for extensions in + // the given DescriptorPool. + // + // DEPRECATED: Use CodedInputStream::SetExtensionRegistry() to tell the + // parser to look for extensions in an alternate pool. However, note that + // this is almost never what you want to do. Almost all users should use + // the zero-arg constructor. + DynamicMessageFactory(const DescriptorPool* pool); + + ~DynamicMessageFactory(); + + // Call this to tell the DynamicMessageFactory that if it is given a + // Descriptor d for which: + // d->file()->pool() == DescriptorPool::generated_pool(), + // then it should delegate to MessageFactory::generated_factory() instead + // of constructing a dynamic implementation of the message. In theory there + // is no down side to doing this, so it may become the default in the future. + void SetDelegateToGeneratedFactory(bool enable) { + delegate_to_generated_factory_ = enable; + } + + // implements MessageFactory --------------------------------------- + + // Given a Descriptor, constructs the default (prototype) Message of that + // type. You can then call that message's New() method to construct a + // mutable message of that type. + // + // Calling this method twice with the same Descriptor returns the same + // object. The returned object remains property of the factory and will + // be destroyed when the factory is destroyed. Also, any objects created + // by calling the prototype's New() method share some data with the + // prototype, so these must be destroyed before the DynamicMessageFactory + // is destroyed. + // + // The given descriptor must outlive the returned message, and hence must + // outlive the DynamicMessageFactory. + // + // The method is thread-safe. + const Message* GetPrototype(const Descriptor* type) override; + + private: + const DescriptorPool* pool_; + bool delegate_to_generated_factory_; + + // This struct just contains a hash_map. We can't #include from + // this header due to hacks needed for hash_map portability in the open source + // release. Namely, stubs/hash.h, which defines hash_map portably, is not a + // public header (for good reason), but dynamic_message.h is, and public + // headers may only #include other public headers. + struct PrototypeMap; + std::unique_ptr prototypes_; + mutable internal::WrappedMutex prototypes_mutex_; + + friend class DynamicMessage; + const Message* GetPrototypeNoLock(const Descriptor* type); + + // Construct default oneof instance for reflection usage if oneof + // is defined. + static void ConstructDefaultOneofInstance(const Descriptor* type, + const uint32 offsets[], + void* default_oneof_instance); + // Delete default oneof instance. Called by ~DynamicMessageFactory. + static void DeleteDefaultOneofInstance(const Descriptor* type, + const uint32 offsets[], + const void* default_oneof_instance); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DynamicMessageFactory); +}; + +// Helper for computing a sorted list of map entries via reflection. +class PROTOBUF_EXPORT DynamicMapSorter { + public: + static std::vector Sort(const Message& message, int map_size, + const Reflection* reflection, + const FieldDescriptor* field) { + std::vector result; + result.reserve(map_size); + RepeatedFieldRef map_field = + reflection->GetRepeatedFieldRef(message, field); + for (auto it = map_field.begin(); it != map_field.end(); ++it) { + result.push_back(&*it); + } + MapEntryMessageComparator comparator(field->message_type()); + std::stable_sort(result.begin(), result.end(), comparator); + // Complain if the keys aren't in ascending order. +#ifndef NDEBUG + for (size_t j = 1; j < static_cast(map_size); j++) { + if (!comparator(result[j - 1], result[j])) { + GOOGLE_LOG(ERROR) << (comparator(result[j], result[j - 1]) + ? "internal error in map key sorting" + : "map keys are not unique"); + } + } +#endif + return result; + } + + private: + class PROTOBUF_EXPORT MapEntryMessageComparator { + public: + explicit MapEntryMessageComparator(const Descriptor* descriptor) + : field_(descriptor->field(0)) {} + + bool operator()(const Message* a, const Message* b) { + const Reflection* reflection = a->GetReflection(); + switch (field_->cpp_type()) { + case FieldDescriptor::CPPTYPE_BOOL: { + bool first = reflection->GetBool(*a, field_); + bool second = reflection->GetBool(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_INT32: { + int32 first = reflection->GetInt32(*a, field_); + int32 second = reflection->GetInt32(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_INT64: { + int64 first = reflection->GetInt64(*a, field_); + int64 second = reflection->GetInt64(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint32 first = reflection->GetUInt32(*a, field_); + uint32 second = reflection->GetUInt32(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64 first = reflection->GetUInt64(*a, field_); + uint64 second = reflection->GetUInt64(*b, field_); + return first < second; + } + case FieldDescriptor::CPPTYPE_STRING: { + std::string first = reflection->GetString(*a, field_); + std::string second = reflection->GetString(*b, field_); + return first < second; + } + default: + GOOGLE_LOG(DFATAL) << "Invalid key for map field."; + return true; + } + } + + private: + const FieldDescriptor* field_; + }; +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_DYNAMIC_MESSAGE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/empty.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/empty.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..9ccde906f78b2181cd9aeaa3733bdb7fc3b3dc67 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/empty.pb.h @@ -0,0 +1,218 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/empty.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fempty_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fempty_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fempty_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2fempty_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fempty_2eproto; +PROTOBUF_NAMESPACE_OPEN +class Empty; +class EmptyDefaultTypeInternal; +PROTOBUF_EXPORT extern EmptyDefaultTypeInternal _Empty_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Empty* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +// =================================================================== + +class PROTOBUF_EXPORT Empty PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Empty) */ { + public: + inline Empty() : Empty(nullptr) {} + virtual ~Empty(); + + Empty(const Empty& from); + Empty(Empty&& from) noexcept + : Empty() { + *this = ::std::move(from); + } + + inline Empty& operator=(const Empty& from) { + CopyFrom(from); + return *this; + } + inline Empty& operator=(Empty&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Empty& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Empty* internal_default_instance() { + return reinterpret_cast( + &_Empty_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Empty& a, Empty& b) { + a.Swap(&b); + } + inline void Swap(Empty* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Empty* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Empty* New() const final { + return CreateMaybeMessage(nullptr); + } + + Empty* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Empty& from); + void MergeFrom(const Empty& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Empty* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Empty"; + } + protected: + explicit Empty(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fempty_2eproto); + return ::descriptor_table_google_2fprotobuf_2fempty_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // @@protoc_insertion_point(class_scope:google.protobuf.Empty) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fempty_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Empty + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fempty_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set.h new file mode 100644 index 0000000000000000000000000000000000000000..a8c5bd647d592ec276de51365b5619dcb716e3e7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set.h @@ -0,0 +1,1593 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This header is logically internal, but is made public because it is used +// from protocol-compiler-generated code, which may reside in other components. + +#ifndef GOOGLE_PROTOBUF_EXTENSION_SET_H__ +#define GOOGLE_PROTOBUF_EXTENSION_SET_H__ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +class Arena; +class Descriptor; // descriptor.h +class FieldDescriptor; // descriptor.h +class DescriptorPool; // descriptor.h +class MessageLite; // message_lite.h +class Message; // message.h +class MessageFactory; // message.h +class UnknownFieldSet; // unknown_field_set.h +namespace internal { +class FieldSkipper; // wire_format_lite.h +} // namespace internal +} // namespace protobuf +} // namespace google + +namespace google { +namespace protobuf { +namespace internal { + +class InternalMetadata; + +// Used to store values of type WireFormatLite::FieldType without having to +// #include wire_format_lite.h. Also, ensures that we use only one byte to +// store these values, which is important to keep the layout of +// ExtensionSet::Extension small. +typedef uint8 FieldType; + +// A function which, given an integer value, returns true if the number +// matches one of the defined values for the corresponding enum type. This +// is used with RegisterEnumExtension, below. +typedef bool EnumValidityFunc(int number); + +// Version of the above which takes an argument. This is needed to deal with +// extensions that are not compiled in. +typedef bool EnumValidityFuncWithArg(const void* arg, int number); + +// Information about a registered extension. +struct ExtensionInfo { + inline ExtensionInfo() {} + inline ExtensionInfo(FieldType type_param, bool isrepeated, bool ispacked) + : type(type_param), + is_repeated(isrepeated), + is_packed(ispacked), + descriptor(NULL) {} + + FieldType type; + bool is_repeated; + bool is_packed; + + struct EnumValidityCheck { + EnumValidityFuncWithArg* func; + const void* arg; + }; + + struct MessageInfo { + const MessageLite* prototype; + }; + + union { + EnumValidityCheck enum_validity_check; + MessageInfo message_info; + }; + + // The descriptor for this extension, if one exists and is known. May be + // NULL. Must not be NULL if the descriptor for the extension does not + // live in the same pool as the descriptor for the containing type. + const FieldDescriptor* descriptor; +}; + +// Abstract interface for an object which looks up extension definitions. Used +// when parsing. +class PROTOBUF_EXPORT ExtensionFinder { + public: + virtual ~ExtensionFinder(); + + // Find the extension with the given containing type and number. + virtual bool Find(int number, ExtensionInfo* output) = 0; +}; + +// Implementation of ExtensionFinder which finds extensions defined in .proto +// files which have been compiled into the binary. +class PROTOBUF_EXPORT GeneratedExtensionFinder : public ExtensionFinder { + public: + GeneratedExtensionFinder(const MessageLite* containing_type) + : containing_type_(containing_type) {} + ~GeneratedExtensionFinder() override {} + + // Returns true and fills in *output if found, otherwise returns false. + bool Find(int number, ExtensionInfo* output) override; + + private: + const MessageLite* containing_type_; +}; + +// A FieldSkipper used for parsing MessageSet. +class MessageSetFieldSkipper; + +// Note: extension_set_heavy.cc defines DescriptorPoolExtensionFinder for +// finding extensions from a DescriptorPool. + +// This is an internal helper class intended for use within the protocol buffer +// library and generated classes. Clients should not use it directly. Instead, +// use the generated accessors such as GetExtension() of the class being +// extended. +// +// This class manages extensions for a protocol message object. The +// message's HasExtension(), GetExtension(), MutableExtension(), and +// ClearExtension() methods are just thin wrappers around the embedded +// ExtensionSet. When parsing, if a tag number is encountered which is +// inside one of the message type's extension ranges, the tag is passed +// off to the ExtensionSet for parsing. Etc. +class PROTOBUF_EXPORT ExtensionSet { + public: + ExtensionSet(); + explicit ExtensionSet(Arena* arena); + ~ExtensionSet(); + + // These are called at startup by protocol-compiler-generated code to + // register known extensions. The registrations are used by ParseField() + // to look up extensions for parsed field numbers. Note that dynamic parsing + // does not use ParseField(); only protocol-compiler-generated parsing + // methods do. + static void RegisterExtension(const MessageLite* containing_type, int number, + FieldType type, bool is_repeated, + bool is_packed); + static void RegisterEnumExtension(const MessageLite* containing_type, + int number, FieldType type, + bool is_repeated, bool is_packed, + EnumValidityFunc* is_valid); + static void RegisterMessageExtension(const MessageLite* containing_type, + int number, FieldType type, + bool is_repeated, bool is_packed, + const MessageLite* prototype); + + // ================================================================= + + // Add all fields which are currently present to the given vector. This + // is useful to implement Reflection::ListFields(). + void AppendToList(const Descriptor* containing_type, + const DescriptorPool* pool, + std::vector* output) const; + + // ================================================================= + // Accessors + // + // Generated message classes include type-safe templated wrappers around + // these methods. Generally you should use those rather than call these + // directly, unless you are doing low-level memory management. + // + // When calling any of these accessors, the extension number requested + // MUST exist in the DescriptorPool provided to the constructor. Otherwise, + // the method will fail an assert. Normally, though, you would not call + // these directly; you would either call the generated accessors of your + // message class (e.g. GetExtension()) or you would call the accessors + // of the reflection interface. In both cases, it is impossible to + // trigger this assert failure: the generated accessors only accept + // linked-in extension types as parameters, while the Reflection interface + // requires you to provide the FieldDescriptor describing the extension. + // + // When calling any of these accessors, a protocol-compiler-generated + // implementation of the extension corresponding to the number MUST + // be linked in, and the FieldDescriptor used to refer to it MUST be + // the one generated by that linked-in code. Otherwise, the method will + // die on an assert failure. The message objects returned by the message + // accessors are guaranteed to be of the correct linked-in type. + // + // These methods pretty much match Reflection except that: + // - They're not virtual. + // - They identify fields by number rather than FieldDescriptors. + // - They identify enum values using integers rather than descriptors. + // - Strings provide Mutable() in addition to Set() accessors. + + bool Has(int number) const; + int ExtensionSize(int number) const; // Size of a repeated extension. + int NumExtensions() const; // The number of extensions + FieldType ExtensionType(int number) const; + void ClearExtension(int number); + + // singular fields ------------------------------------------------- + + int32 GetInt32(int number, int32 default_value) const; + int64 GetInt64(int number, int64 default_value) const; + uint32 GetUInt32(int number, uint32 default_value) const; + uint64 GetUInt64(int number, uint64 default_value) const; + float GetFloat(int number, float default_value) const; + double GetDouble(int number, double default_value) const; + bool GetBool(int number, bool default_value) const; + int GetEnum(int number, int default_value) const; + const std::string& GetString(int number, + const std::string& default_value) const; + const MessageLite& GetMessage(int number, + const MessageLite& default_value) const; + const MessageLite& GetMessage(int number, const Descriptor* message_type, + MessageFactory* factory) const; + + // |descriptor| may be NULL so long as it is known that the descriptor for + // the extension lives in the same pool as the descriptor for the containing + // type. +#define desc const FieldDescriptor* descriptor // avoid line wrapping + void SetInt32(int number, FieldType type, int32 value, desc); + void SetInt64(int number, FieldType type, int64 value, desc); + void SetUInt32(int number, FieldType type, uint32 value, desc); + void SetUInt64(int number, FieldType type, uint64 value, desc); + void SetFloat(int number, FieldType type, float value, desc); + void SetDouble(int number, FieldType type, double value, desc); + void SetBool(int number, FieldType type, bool value, desc); + void SetEnum(int number, FieldType type, int value, desc); + void SetString(int number, FieldType type, std::string value, desc); + std::string* MutableString(int number, FieldType type, desc); + MessageLite* MutableMessage(int number, FieldType type, + const MessageLite& prototype, desc); + MessageLite* MutableMessage(const FieldDescriptor* descriptor, + MessageFactory* factory); + // Adds the given message to the ExtensionSet, taking ownership of the + // message object. Existing message with the same number will be deleted. + // If "message" is NULL, this is equivalent to "ClearExtension(number)". + void SetAllocatedMessage(int number, FieldType type, + const FieldDescriptor* descriptor, + MessageLite* message); + void UnsafeArenaSetAllocatedMessage(int number, FieldType type, + const FieldDescriptor* descriptor, + MessageLite* message); + MessageLite* ReleaseMessage(int number, const MessageLite& prototype); + MessageLite* UnsafeArenaReleaseMessage(int number, + const MessageLite& prototype); + + MessageLite* ReleaseMessage(const FieldDescriptor* descriptor, + MessageFactory* factory); + MessageLite* UnsafeArenaReleaseMessage(const FieldDescriptor* descriptor, + MessageFactory* factory); +#undef desc + Arena* GetArena() const { return arena_; } + + // repeated fields ------------------------------------------------- + + // Fetches a RepeatedField extension by number; returns |default_value| + // if no such extension exists. User should not touch this directly; it is + // used by the GetRepeatedExtension() method. + const void* GetRawRepeatedField(int number, const void* default_value) const; + // Fetches a mutable version of a RepeatedField extension by number, + // instantiating one if none exists. Similar to above, user should not use + // this directly; it underlies MutableRepeatedExtension(). + void* MutableRawRepeatedField(int number, FieldType field_type, bool packed, + const FieldDescriptor* desc); + + // This is an overload of MutableRawRepeatedField to maintain compatibility + // with old code using a previous API. This version of + // MutableRawRepeatedField() will GOOGLE_CHECK-fail on a missing extension. + // (E.g.: borg/clients/internal/proto1/proto2_reflection.cc.) + void* MutableRawRepeatedField(int number); + + int32 GetRepeatedInt32(int number, int index) const; + int64 GetRepeatedInt64(int number, int index) const; + uint32 GetRepeatedUInt32(int number, int index) const; + uint64 GetRepeatedUInt64(int number, int index) const; + float GetRepeatedFloat(int number, int index) const; + double GetRepeatedDouble(int number, int index) const; + bool GetRepeatedBool(int number, int index) const; + int GetRepeatedEnum(int number, int index) const; + const std::string& GetRepeatedString(int number, int index) const; + const MessageLite& GetRepeatedMessage(int number, int index) const; + + void SetRepeatedInt32(int number, int index, int32 value); + void SetRepeatedInt64(int number, int index, int64 value); + void SetRepeatedUInt32(int number, int index, uint32 value); + void SetRepeatedUInt64(int number, int index, uint64 value); + void SetRepeatedFloat(int number, int index, float value); + void SetRepeatedDouble(int number, int index, double value); + void SetRepeatedBool(int number, int index, bool value); + void SetRepeatedEnum(int number, int index, int value); + void SetRepeatedString(int number, int index, std::string value); + std::string* MutableRepeatedString(int number, int index); + MessageLite* MutableRepeatedMessage(int number, int index); + +#define desc const FieldDescriptor* descriptor // avoid line wrapping + void AddInt32(int number, FieldType type, bool packed, int32 value, desc); + void AddInt64(int number, FieldType type, bool packed, int64 value, desc); + void AddUInt32(int number, FieldType type, bool packed, uint32 value, desc); + void AddUInt64(int number, FieldType type, bool packed, uint64 value, desc); + void AddFloat(int number, FieldType type, bool packed, float value, desc); + void AddDouble(int number, FieldType type, bool packed, double value, desc); + void AddBool(int number, FieldType type, bool packed, bool value, desc); + void AddEnum(int number, FieldType type, bool packed, int value, desc); + void AddString(int number, FieldType type, std::string value, desc); + std::string* AddString(int number, FieldType type, desc); + MessageLite* AddMessage(int number, FieldType type, + const MessageLite& prototype, desc); + MessageLite* AddMessage(const FieldDescriptor* descriptor, + MessageFactory* factory); + void AddAllocatedMessage(const FieldDescriptor* descriptor, + MessageLite* new_entry); +#undef desc + + void RemoveLast(int number); + MessageLite* ReleaseLast(int number); + void SwapElements(int number, int index1, int index2); + + // ----------------------------------------------------------------- + // TODO(kenton): Hardcore memory management accessors + + // ================================================================= + // convenience methods for implementing methods of Message + // + // These could all be implemented in terms of the other methods of this + // class, but providing them here helps keep the generated code size down. + + void Clear(); + void MergeFrom(const ExtensionSet& other); + void Swap(ExtensionSet* other); + void SwapExtension(ExtensionSet* other, int number); + bool IsInitialized() const; + + // Parses a single extension from the input. The input should start out + // positioned immediately after the tag. + bool ParseField(uint32 tag, io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper); + + // Specific versions for lite or full messages (constructs the appropriate + // FieldSkipper automatically). |containing_type| is the default + // instance for the containing message; it is used only to look up the + // extension by number. See RegisterExtension(), above. Unlike the other + // methods of ExtensionSet, this only works for generated message types -- + // it looks up extensions registered using RegisterExtension(). + bool ParseField(uint32 tag, io::CodedInputStream* input, + const MessageLite* containing_type); + bool ParseField(uint32 tag, io::CodedInputStream* input, + const Message* containing_type, + UnknownFieldSet* unknown_fields); + bool ParseField(uint32 tag, io::CodedInputStream* input, + const MessageLite* containing_type, + io::CodedOutputStream* unknown_fields); + + // Lite parser + const char* ParseField(uint64 tag, const char* ptr, + const MessageLite* containing_type, + internal::InternalMetadata* metadata, + internal::ParseContext* ctx); + // Full parser + const char* ParseField(uint64 tag, const char* ptr, + const Message* containing_type, + internal::InternalMetadata* metadata, + internal::ParseContext* ctx); + template + const char* ParseMessageSet(const char* ptr, const Msg* containing_type, + InternalMetadata* metadata, + internal::ParseContext* ctx) { + struct MessageSetItem { + const char* _InternalParse(const char* ptr, ParseContext* ctx) { + return me->ParseMessageSetItem(ptr, containing_type, metadata, ctx); + } + ExtensionSet* me; + const Msg* containing_type; + InternalMetadata* metadata; + } item{this, containing_type, metadata}; + while (!ctx->Done(&ptr)) { + uint32 tag; + ptr = ReadTag(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + if (tag == WireFormatLite::kMessageSetItemStartTag) { + ptr = ctx->ParseGroup(&item, ptr, tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + } else { + if (tag == 0 || (tag & 7) == 4) { + ctx->SetLastTag(tag); + return ptr; + } + ptr = ParseField(tag, ptr, containing_type, metadata, ctx); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + } + } + return ptr; + } + + // Parse an entire message in MessageSet format. Such messages have no + // fields, only extensions. + bool ParseMessageSetLite(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper); + bool ParseMessageSet(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + MessageSetFieldSkipper* field_skipper); + + // Specific versions for lite or full messages (constructs the appropriate + // FieldSkipper automatically). + bool ParseMessageSet(io::CodedInputStream* input, + const MessageLite* containing_type, + std::string* unknown_fields); + bool ParseMessageSet(io::CodedInputStream* input, + const Message* containing_type, + UnknownFieldSet* unknown_fields); + + // Write all extension fields with field numbers in the range + // [start_field_number, end_field_number) + // to the output stream, using the cached sizes computed when ByteSize() was + // last called. Note that the range bounds are inclusive-exclusive. + void SerializeWithCachedSizes(int start_field_number, int end_field_number, + io::CodedOutputStream* output) const { + output->SetCur(_InternalSerialize(start_field_number, end_field_number, + output->Cur(), output->EpsCopy())); + } + + // Same as SerializeWithCachedSizes, but without any bounds checking. + // The caller must ensure that target has sufficient capacity for the + // serialized extensions. + // + // Returns a pointer past the last written byte. + uint8* _InternalSerialize(int start_field_number, int end_field_number, + uint8* target, + io::EpsCopyOutputStream* stream) const; + + // Like above but serializes in MessageSet format. + void SerializeMessageSetWithCachedSizes(io::CodedOutputStream* output) const { + output->SetCur(InternalSerializeMessageSetWithCachedSizesToArray( + output->Cur(), output->EpsCopy())); + } + uint8* InternalSerializeMessageSetWithCachedSizesToArray( + uint8* target, io::EpsCopyOutputStream* stream) const; + + // For backward-compatibility, versions of two of the above methods that + // serialize deterministically iff SetDefaultSerializationDeterministic() + // has been called. + uint8* SerializeWithCachedSizesToArray(int start_field_number, + int end_field_number, + uint8* target) const; + uint8* SerializeMessageSetWithCachedSizesToArray(uint8* target) const; + + // Returns the total serialized size of all the extensions. + size_t ByteSize() const; + + // Like ByteSize() but uses MessageSet format. + size_t MessageSetByteSize() const; + + // Returns (an estimate of) the total number of bytes used for storing the + // extensions in memory, excluding sizeof(*this). If the ExtensionSet is + // for a lite message (and thus possibly contains lite messages), the results + // are undefined (might work, might crash, might corrupt data, might not even + // be linked in). It's up to the protocol compiler to avoid calling this on + // such ExtensionSets (easy enough since lite messages don't implement + // SpaceUsed()). + size_t SpaceUsedExcludingSelfLong() const; + + // This method just calls SpaceUsedExcludingSelfLong() but it can not be + // inlined because the definition of SpaceUsedExcludingSelfLong() is not + // included in lite runtime and when an inline method refers to it MSVC + // will complain about unresolved symbols when building the lite runtime + // as .dll. + int SpaceUsedExcludingSelf() const; + + private: + // Interface of a lazily parsed singular message extension. + class PROTOBUF_EXPORT LazyMessageExtension { + public: + LazyMessageExtension() {} + virtual ~LazyMessageExtension() {} + + virtual LazyMessageExtension* New(Arena* arena) const = 0; + virtual const MessageLite& GetMessage( + const MessageLite& prototype) const = 0; + virtual MessageLite* MutableMessage(const MessageLite& prototype) = 0; + virtual void SetAllocatedMessage(MessageLite* message) = 0; + virtual void UnsafeArenaSetAllocatedMessage(MessageLite* message) = 0; + virtual MessageLite* ReleaseMessage(const MessageLite& prototype) = 0; + virtual MessageLite* UnsafeArenaReleaseMessage( + const MessageLite& prototype) = 0; + + virtual bool IsInitialized() const = 0; + + PROTOBUF_DEPRECATED_MSG("Please use ByteSizeLong() instead") + virtual int ByteSize() const { return internal::ToIntSize(ByteSizeLong()); } + virtual size_t ByteSizeLong() const = 0; + virtual size_t SpaceUsedLong() const = 0; + + virtual void MergeFrom(const LazyMessageExtension& other) = 0; + virtual void Clear() = 0; + + virtual bool ReadMessage(const MessageLite& prototype, + io::CodedInputStream* input) = 0; + virtual const char* _InternalParse(const char* ptr, ParseContext* ctx) = 0; + virtual uint8* WriteMessageToArray( + int number, uint8* target, io::EpsCopyOutputStream* stream) const = 0; + + private: + virtual void UnusedKeyMethod(); // Dummy key method to avoid weak vtable. + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(LazyMessageExtension); + }; + struct Extension { + // The order of these fields packs Extension into 24 bytes when using 8 + // byte alignment. Consider this when adding or removing fields here. + union { + int32 int32_value; + int64 int64_value; + uint32 uint32_value; + uint64 uint64_value; + float float_value; + double double_value; + bool bool_value; + int enum_value; + std::string* string_value; + MessageLite* message_value; + LazyMessageExtension* lazymessage_value; + + RepeatedField* repeated_int32_value; + RepeatedField* repeated_int64_value; + RepeatedField* repeated_uint32_value; + RepeatedField* repeated_uint64_value; + RepeatedField* repeated_float_value; + RepeatedField* repeated_double_value; + RepeatedField* repeated_bool_value; + RepeatedField* repeated_enum_value; + RepeatedPtrField* repeated_string_value; + RepeatedPtrField* repeated_message_value; + }; + + FieldType type; + bool is_repeated; + + // For singular types, indicates if the extension is "cleared". This + // happens when an extension is set and then later cleared by the caller. + // We want to keep the Extension object around for reuse, so instead of + // removing it from the map, we just set is_cleared = true. This has no + // meaning for repeated types; for those, the size of the RepeatedField + // simply becomes zero when cleared. + bool is_cleared : 4; + + // For singular message types, indicates whether lazy parsing is enabled + // for this extension. This field is only valid when type == TYPE_MESSAGE + // and !is_repeated because we only support lazy parsing for singular + // message types currently. If is_lazy = true, the extension is stored in + // lazymessage_value. Otherwise, the extension will be message_value. + bool is_lazy : 4; + + // For repeated types, this indicates if the [packed=true] option is set. + bool is_packed; + + // For packed fields, the size of the packed data is recorded here when + // ByteSize() is called then used during serialization. + // TODO(kenton): Use atomic when C++ supports it. + mutable int cached_size; + + // The descriptor for this extension, if one exists and is known. May be + // NULL. Must not be NULL if the descriptor for the extension does not + // live in the same pool as the descriptor for the containing type. + const FieldDescriptor* descriptor; + + // Some helper methods for operations on a single Extension. + uint8* InternalSerializeFieldWithCachedSizesToArray( + int number, uint8* target, io::EpsCopyOutputStream* stream) const; + uint8* InternalSerializeMessageSetItemWithCachedSizesToArray( + int number, uint8* target, io::EpsCopyOutputStream* stream) const; + size_t ByteSize(int number) const; + size_t MessageSetItemByteSize(int number) const; + void Clear(); + int GetSize() const; + void Free(); + size_t SpaceUsedExcludingSelfLong() const; + bool IsInitialized() const; + }; + + // The Extension struct is small enough to be passed by value, so we use it + // directly as the value type in mappings rather than use pointers. We use + // sorted maps rather than hash-maps because we expect most ExtensionSets will + // only contain a small number of extension. Also, we want AppendToList and + // deterministic serialization to order fields by field number. + + struct KeyValue { + int first; + Extension second; + + struct FirstComparator { + bool operator()(const KeyValue& lhs, const KeyValue& rhs) const { + return lhs.first < rhs.first; + } + bool operator()(const KeyValue& lhs, int key) const { + return lhs.first < key; + } + bool operator()(int key, const KeyValue& rhs) const { + return key < rhs.first; + } + }; + }; + + typedef std::map LargeMap; + + // Wrapper API that switches between flat-map and LargeMap. + + // Finds a key (if present) in the ExtensionSet. + const Extension* FindOrNull(int key) const; + Extension* FindOrNull(int key); + + // Helper-functions that only inspect the LargeMap. + const Extension* FindOrNullInLargeMap(int key) const; + Extension* FindOrNullInLargeMap(int key); + + // Inserts a new (key, Extension) into the ExtensionSet (and returns true), or + // finds the already-existing Extension for that key (returns false). + // The Extension* will point to the new-or-found Extension. + std::pair Insert(int key); + + // Grows the flat_capacity_. + // If flat_capacity_ > kMaximumFlatCapacity, converts to LargeMap. + void GrowCapacity(size_t minimum_new_capacity); + static constexpr uint16 kMaximumFlatCapacity = 256; + bool is_large() const { return flat_capacity_ > kMaximumFlatCapacity; } + + // Removes a key from the ExtensionSet. + void Erase(int key); + + size_t Size() const { + return PROTOBUF_PREDICT_FALSE(is_large()) ? map_.large->size() : flat_size_; + } + + // Similar to std::for_each. + // Each Iterator is decomposed into ->first and ->second fields, so + // that the KeyValueFunctor can be agnostic vis-a-vis KeyValue-vs-std::pair. + template + static KeyValueFunctor ForEach(Iterator begin, Iterator end, + KeyValueFunctor func) { + for (Iterator it = begin; it != end; ++it) func(it->first, it->second); + return std::move(func); + } + + // Applies a functor to the pairs in sorted order. + template + KeyValueFunctor ForEach(KeyValueFunctor func) { + if (PROTOBUF_PREDICT_FALSE(is_large())) { + return ForEach(map_.large->begin(), map_.large->end(), std::move(func)); + } + return ForEach(flat_begin(), flat_end(), std::move(func)); + } + + // Applies a functor to the pairs in sorted order. + template + KeyValueFunctor ForEach(KeyValueFunctor func) const { + if (PROTOBUF_PREDICT_FALSE(is_large())) { + return ForEach(map_.large->begin(), map_.large->end(), std::move(func)); + } + return ForEach(flat_begin(), flat_end(), std::move(func)); + } + + // Merges existing Extension from other_extension + void InternalExtensionMergeFrom(int number, const Extension& other_extension); + + // Returns true and fills field_number and extension if extension is found. + // Note to support packed repeated field compatibility, it also fills whether + // the tag on wire is packed, which can be different from + // extension->is_packed (whether packed=true is specified). + bool FindExtensionInfoFromTag(uint32 tag, ExtensionFinder* extension_finder, + int* field_number, ExtensionInfo* extension, + bool* was_packed_on_wire); + + // Returns true and fills extension if extension is found. + // Note to support packed repeated field compatibility, it also fills whether + // the tag on wire is packed, which can be different from + // extension->is_packed (whether packed=true is specified). + bool FindExtensionInfoFromFieldNumber(int wire_type, int field_number, + ExtensionFinder* extension_finder, + ExtensionInfo* extension, + bool* was_packed_on_wire); + + // Parses a single extension from the input. The input should start out + // positioned immediately after the wire tag. This method is called in + // ParseField() after field number and was_packed_on_wire is extracted from + // the wire tag and ExtensionInfo is found by the field number. + bool ParseFieldWithExtensionInfo(int field_number, bool was_packed_on_wire, + const ExtensionInfo& extension, + io::CodedInputStream* input, + FieldSkipper* field_skipper); + + // Like ParseField(), but this method may parse singular message extensions + // lazily depending on the value of FLAGS_eagerly_parse_message_sets. + bool ParseFieldMaybeLazily(int wire_type, int field_number, + io::CodedInputStream* input, + ExtensionFinder* extension_finder, + MessageSetFieldSkipper* field_skipper); + + // Gets the extension with the given number, creating it if it does not + // already exist. Returns true if the extension did not already exist. + bool MaybeNewExtension(int number, const FieldDescriptor* descriptor, + Extension** result); + + // Gets the repeated extension for the given descriptor, creating it if + // it does not exist. + Extension* MaybeNewRepeatedExtension(const FieldDescriptor* descriptor); + + // Parse a single MessageSet item -- called just after the item group start + // tag has been read. + bool ParseMessageSetItemLite(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + FieldSkipper* field_skipper); + // Parse a single MessageSet item -- called just after the item group start + // tag has been read. + bool ParseMessageSetItem(io::CodedInputStream* input, + ExtensionFinder* extension_finder, + MessageSetFieldSkipper* field_skipper); + + bool FindExtension(int wire_type, uint32 field, + const MessageLite* containing_type, + const internal::ParseContext* /*ctx*/, + ExtensionInfo* extension, bool* was_packed_on_wire) { + GeneratedExtensionFinder finder(containing_type); + return FindExtensionInfoFromFieldNumber(wire_type, field, &finder, + extension, was_packed_on_wire); + } + inline bool FindExtension(int wire_type, uint32 field, + const Message* containing_type, + const internal::ParseContext* ctx, + ExtensionInfo* extension, bool* was_packed_on_wire); + // Used for MessageSet only + const char* ParseFieldMaybeLazily(uint64 tag, const char* ptr, + const MessageLite* containing_type, + internal::InternalMetadata* metadata, + internal::ParseContext* ctx) { + // Lite MessageSet doesn't implement lazy. + return ParseField(tag, ptr, containing_type, metadata, ctx); + } + const char* ParseFieldMaybeLazily(uint64 tag, const char* ptr, + const Message* containing_type, + internal::InternalMetadata* metadata, + internal::ParseContext* ctx); + const char* ParseMessageSetItem(const char* ptr, + const MessageLite* containing_type, + internal::InternalMetadata* metadata, + internal::ParseContext* ctx); + const char* ParseMessageSetItem(const char* ptr, + const Message* containing_type, + internal::InternalMetadata* metadata, + internal::ParseContext* ctx); + + // Implemented in extension_set_inl.h to keep code out of the header file. + template + const char* ParseFieldWithExtensionInfo(int number, bool was_packed_on_wire, + const ExtensionInfo& info, + internal::InternalMetadata* metadata, + const char* ptr, + internal::ParseContext* ctx); + template + const char* ParseMessageSetItemTmpl(const char* ptr, + const Msg* containing_type, + internal::InternalMetadata* metadata, + internal::ParseContext* ctx); + + // Hack: RepeatedPtrFieldBase declares ExtensionSet as a friend. This + // friendship should automatically extend to ExtensionSet::Extension, but + // unfortunately some older compilers (e.g. GCC 3.4.4) do not implement this + // correctly. So, we must provide helpers for calling methods of that + // class. + + // Defined in extension_set_heavy.cc. + static inline size_t RepeatedMessage_SpaceUsedExcludingSelfLong( + RepeatedPtrFieldBase* field); + + KeyValue* flat_begin() { + assert(!is_large()); + return map_.flat; + } + const KeyValue* flat_begin() const { + assert(!is_large()); + return map_.flat; + } + KeyValue* flat_end() { + assert(!is_large()); + return map_.flat + flat_size_; + } + const KeyValue* flat_end() const { + assert(!is_large()); + return map_.flat + flat_size_; + } + + Arena* arena_; + + // Manual memory-management: + // map_.flat is an allocated array of flat_capacity_ elements. + // [map_.flat, map_.flat + flat_size_) is the currently-in-use prefix. + uint16 flat_capacity_; + uint16 flat_size_; + union AllocatedData { + KeyValue* flat; + + // If flat_capacity_ > kMaximumFlatCapacity, switch to LargeMap, + // which guarantees O(n lg n) CPU but larger constant factors. + LargeMap* large; + } map_; + + static void DeleteFlatMap(const KeyValue* flat, uint16 flat_capacity); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ExtensionSet); +}; + +// These are just for convenience... +inline void ExtensionSet::SetString(int number, FieldType type, + std::string value, + const FieldDescriptor* descriptor) { + MutableString(number, type, descriptor)->assign(std::move(value)); +} +inline void ExtensionSet::SetRepeatedString(int number, int index, + std::string value) { + MutableRepeatedString(number, index)->assign(std::move(value)); +} +inline void ExtensionSet::AddString(int number, FieldType type, + std::string value, + const FieldDescriptor* descriptor) { + AddString(number, type, descriptor)->assign(std::move(value)); +} +// =================================================================== +// Glue for generated extension accessors + +// ------------------------------------------------------------------- +// Template magic + +// First we have a set of classes representing "type traits" for different +// field types. A type traits class knows how to implement basic accessors +// for extensions of a particular type given an ExtensionSet. The signature +// for a type traits class looks like this: +// +// class TypeTraits { +// public: +// typedef ? ConstType; +// typedef ? MutableType; +// // TypeTraits for singular fields and repeated fields will define the +// // symbol "Singular" or "Repeated" respectively. These two symbols will +// // be used in extension accessors to distinguish between singular +// // extensions and repeated extensions. If the TypeTraits for the passed +// // in extension doesn't have the expected symbol defined, it means the +// // user is passing a repeated extension to a singular accessor, or the +// // opposite. In that case the C++ compiler will generate an error +// // message "no matching member function" to inform the user. +// typedef ? Singular +// typedef ? Repeated +// +// static inline ConstType Get(int number, const ExtensionSet& set); +// static inline void Set(int number, ConstType value, ExtensionSet* set); +// static inline MutableType Mutable(int number, ExtensionSet* set); +// +// // Variants for repeated fields. +// static inline ConstType Get(int number, const ExtensionSet& set, +// int index); +// static inline void Set(int number, int index, +// ConstType value, ExtensionSet* set); +// static inline MutableType Mutable(int number, int index, +// ExtensionSet* set); +// static inline void Add(int number, ConstType value, ExtensionSet* set); +// static inline MutableType Add(int number, ExtensionSet* set); +// This is used by the ExtensionIdentifier constructor to register +// the extension at dynamic initialization. +// template +// static void Register(int number, FieldType type, bool is_packed); +// }; +// +// Not all of these methods make sense for all field types. For example, the +// "Mutable" methods only make sense for strings and messages, and the +// repeated methods only make sense for repeated types. So, each type +// traits class implements only the set of methods from this signature that it +// actually supports. This will cause a compiler error if the user tries to +// access an extension using a method that doesn't make sense for its type. +// For example, if "foo" is an extension of type "optional int32", then if you +// try to write code like: +// my_message.MutableExtension(foo) +// you will get a compile error because PrimitiveTypeTraits does not +// have a "Mutable()" method. + +// ------------------------------------------------------------------- +// PrimitiveTypeTraits + +// Since the ExtensionSet has different methods for each primitive type, +// we must explicitly define the methods of the type traits class for each +// known type. +template +class PrimitiveTypeTraits { + public: + typedef Type ConstType; + typedef Type MutableType; + typedef PrimitiveTypeTraits Singular; + + static inline ConstType Get(int number, const ExtensionSet& set, + ConstType default_value); + static inline void Set(int number, FieldType field_type, ConstType value, + ExtensionSet* set); + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterExtension(&ExtendeeT::default_instance(), number, + type, false, is_packed); + } +}; + +template +class RepeatedPrimitiveTypeTraits { + public: + typedef Type ConstType; + typedef Type MutableType; + typedef RepeatedPrimitiveTypeTraits Repeated; + + typedef RepeatedField RepeatedFieldType; + + static inline Type Get(int number, const ExtensionSet& set, int index); + static inline void Set(int number, int index, Type value, ExtensionSet* set); + static inline void Add(int number, FieldType field_type, bool is_packed, + Type value, ExtensionSet* set); + + static inline const RepeatedField& GetRepeated( + int number, const ExtensionSet& set); + static inline RepeatedField* MutableRepeated(int number, + FieldType field_type, + bool is_packed, + ExtensionSet* set); + + static const RepeatedFieldType* GetDefaultRepeatedField(); + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterExtension(&ExtendeeT::default_instance(), number, + type, true, is_packed); + } +}; + +class PROTOBUF_EXPORT RepeatedPrimitiveDefaults { + private: + template + friend class RepeatedPrimitiveTypeTraits; + static const RepeatedPrimitiveDefaults* default_instance(); + RepeatedField default_repeated_field_int32_; + RepeatedField default_repeated_field_int64_; + RepeatedField default_repeated_field_uint32_; + RepeatedField default_repeated_field_uint64_; + RepeatedField default_repeated_field_double_; + RepeatedField default_repeated_field_float_; + RepeatedField default_repeated_field_bool_; +}; + +#define PROTOBUF_DEFINE_PRIMITIVE_TYPE(TYPE, METHOD) \ + template <> \ + inline TYPE PrimitiveTypeTraits::Get( \ + int number, const ExtensionSet& set, TYPE default_value) { \ + return set.Get##METHOD(number, default_value); \ + } \ + template <> \ + inline void PrimitiveTypeTraits::Set(int number, FieldType field_type, \ + TYPE value, ExtensionSet* set) { \ + set->Set##METHOD(number, field_type, value, NULL); \ + } \ + \ + template <> \ + inline TYPE RepeatedPrimitiveTypeTraits::Get( \ + int number, const ExtensionSet& set, int index) { \ + return set.GetRepeated##METHOD(number, index); \ + } \ + template <> \ + inline void RepeatedPrimitiveTypeTraits::Set( \ + int number, int index, TYPE value, ExtensionSet* set) { \ + set->SetRepeated##METHOD(number, index, value); \ + } \ + template <> \ + inline void RepeatedPrimitiveTypeTraits::Add( \ + int number, FieldType field_type, bool is_packed, TYPE value, \ + ExtensionSet* set) { \ + set->Add##METHOD(number, field_type, is_packed, value, NULL); \ + } \ + template <> \ + inline const RepeatedField* \ + RepeatedPrimitiveTypeTraits::GetDefaultRepeatedField() { \ + return &RepeatedPrimitiveDefaults::default_instance() \ + ->default_repeated_field_##TYPE##_; \ + } \ + template <> \ + inline const RepeatedField& \ + RepeatedPrimitiveTypeTraits::GetRepeated(int number, \ + const ExtensionSet& set) { \ + return *reinterpret_cast*>( \ + set.GetRawRepeatedField(number, GetDefaultRepeatedField())); \ + } \ + template <> \ + inline RepeatedField* \ + RepeatedPrimitiveTypeTraits::MutableRepeated( \ + int number, FieldType field_type, bool is_packed, ExtensionSet* set) { \ + return reinterpret_cast*>( \ + set->MutableRawRepeatedField(number, field_type, is_packed, NULL)); \ + } + +PROTOBUF_DEFINE_PRIMITIVE_TYPE(int32, Int32) +PROTOBUF_DEFINE_PRIMITIVE_TYPE(int64, Int64) +PROTOBUF_DEFINE_PRIMITIVE_TYPE(uint32, UInt32) +PROTOBUF_DEFINE_PRIMITIVE_TYPE(uint64, UInt64) +PROTOBUF_DEFINE_PRIMITIVE_TYPE(float, Float) +PROTOBUF_DEFINE_PRIMITIVE_TYPE(double, Double) +PROTOBUF_DEFINE_PRIMITIVE_TYPE(bool, Bool) + +#undef PROTOBUF_DEFINE_PRIMITIVE_TYPE + +// ------------------------------------------------------------------- +// StringTypeTraits + +// Strings support both Set() and Mutable(). +class PROTOBUF_EXPORT StringTypeTraits { + public: + typedef const std::string& ConstType; + typedef std::string* MutableType; + typedef StringTypeTraits Singular; + + static inline const std::string& Get(int number, const ExtensionSet& set, + ConstType default_value) { + return set.GetString(number, default_value); + } + static inline void Set(int number, FieldType field_type, + const std::string& value, ExtensionSet* set) { + set->SetString(number, field_type, value, NULL); + } + static inline std::string* Mutable(int number, FieldType field_type, + ExtensionSet* set) { + return set->MutableString(number, field_type, NULL); + } + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterExtension(&ExtendeeT::default_instance(), number, + type, false, is_packed); + } +}; + +class PROTOBUF_EXPORT RepeatedStringTypeTraits { + public: + typedef const std::string& ConstType; + typedef std::string* MutableType; + typedef RepeatedStringTypeTraits Repeated; + + typedef RepeatedPtrField RepeatedFieldType; + + static inline const std::string& Get(int number, const ExtensionSet& set, + int index) { + return set.GetRepeatedString(number, index); + } + static inline void Set(int number, int index, const std::string& value, + ExtensionSet* set) { + set->SetRepeatedString(number, index, value); + } + static inline std::string* Mutable(int number, int index, ExtensionSet* set) { + return set->MutableRepeatedString(number, index); + } + static inline void Add(int number, FieldType field_type, bool /*is_packed*/, + const std::string& value, ExtensionSet* set) { + set->AddString(number, field_type, value, NULL); + } + static inline std::string* Add(int number, FieldType field_type, + ExtensionSet* set) { + return set->AddString(number, field_type, NULL); + } + static inline const RepeatedPtrField& GetRepeated( + int number, const ExtensionSet& set) { + return *reinterpret_cast*>( + set.GetRawRepeatedField(number, GetDefaultRepeatedField())); + } + + static inline RepeatedPtrField* MutableRepeated( + int number, FieldType field_type, bool is_packed, ExtensionSet* set) { + return reinterpret_cast*>( + set->MutableRawRepeatedField(number, field_type, is_packed, NULL)); + } + + static const RepeatedFieldType* GetDefaultRepeatedField(); + + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterExtension(&ExtendeeT::default_instance(), number, + type, true, is_packed); + } + + private: + static void InitializeDefaultRepeatedFields(); + static void DestroyDefaultRepeatedFields(); +}; + +// ------------------------------------------------------------------- +// EnumTypeTraits + +// ExtensionSet represents enums using integers internally, so we have to +// static_cast around. +template +class EnumTypeTraits { + public: + typedef Type ConstType; + typedef Type MutableType; + typedef EnumTypeTraits Singular; + + static inline ConstType Get(int number, const ExtensionSet& set, + ConstType default_value) { + return static_cast(set.GetEnum(number, default_value)); + } + static inline void Set(int number, FieldType field_type, ConstType value, + ExtensionSet* set) { + GOOGLE_DCHECK(IsValid(value)); + set->SetEnum(number, field_type, value, NULL); + } + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterEnumExtension(&ExtendeeT::default_instance(), number, + type, false, is_packed, IsValid); + } +}; + +template +class RepeatedEnumTypeTraits { + public: + typedef Type ConstType; + typedef Type MutableType; + typedef RepeatedEnumTypeTraits Repeated; + + typedef RepeatedField RepeatedFieldType; + + static inline ConstType Get(int number, const ExtensionSet& set, int index) { + return static_cast(set.GetRepeatedEnum(number, index)); + } + static inline void Set(int number, int index, ConstType value, + ExtensionSet* set) { + GOOGLE_DCHECK(IsValid(value)); + set->SetRepeatedEnum(number, index, value); + } + static inline void Add(int number, FieldType field_type, bool is_packed, + ConstType value, ExtensionSet* set) { + GOOGLE_DCHECK(IsValid(value)); + set->AddEnum(number, field_type, is_packed, value, NULL); + } + static inline const RepeatedField& GetRepeated( + int number, const ExtensionSet& set) { + // Hack: the `Extension` struct stores a RepeatedField for enums. + // RepeatedField cannot implicitly convert to RepeatedField + // so we need to do some casting magic. See message.h for similar + // contortions for non-extension fields. + return *reinterpret_cast*>( + set.GetRawRepeatedField(number, GetDefaultRepeatedField())); + } + + static inline RepeatedField* MutableRepeated(int number, + FieldType field_type, + bool is_packed, + ExtensionSet* set) { + return reinterpret_cast*>( + set->MutableRawRepeatedField(number, field_type, is_packed, NULL)); + } + + static const RepeatedFieldType* GetDefaultRepeatedField() { + // Hack: as noted above, repeated enum fields are internally stored as a + // RepeatedField. We need to be able to instantiate global static + // objects to return as default (empty) repeated fields on non-existent + // extensions. We would not be able to know a-priori all of the enum types + // (values of |Type|) to instantiate all of these, so we just re-use int32's + // default repeated field object. + return reinterpret_cast*>( + RepeatedPrimitiveTypeTraits::GetDefaultRepeatedField()); + } + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterEnumExtension(&ExtendeeT::default_instance(), number, + type, true, is_packed, IsValid); + } +}; + +// ------------------------------------------------------------------- +// MessageTypeTraits + +// ExtensionSet guarantees that when manipulating extensions with message +// types, the implementation used will be the compiled-in class representing +// that type. So, we can static_cast down to the exact type we expect. +template +class MessageTypeTraits { + public: + typedef const Type& ConstType; + typedef Type* MutableType; + typedef MessageTypeTraits Singular; + + static inline ConstType Get(int number, const ExtensionSet& set, + ConstType default_value) { + return static_cast(set.GetMessage(number, default_value)); + } + static inline MutableType Mutable(int number, FieldType field_type, + ExtensionSet* set) { + return static_cast(set->MutableMessage( + number, field_type, Type::default_instance(), NULL)); + } + static inline void SetAllocated(int number, FieldType field_type, + MutableType message, ExtensionSet* set) { + set->SetAllocatedMessage(number, field_type, NULL, message); + } + static inline void UnsafeArenaSetAllocated(int number, FieldType field_type, + MutableType message, + ExtensionSet* set) { + set->UnsafeArenaSetAllocatedMessage(number, field_type, NULL, message); + } + static inline MutableType Release(int number, FieldType /* field_type */, + ExtensionSet* set) { + return static_cast( + set->ReleaseMessage(number, Type::default_instance())); + } + static inline MutableType UnsafeArenaRelease(int number, + FieldType /* field_type */, + ExtensionSet* set) { + return static_cast( + set->UnsafeArenaReleaseMessage(number, Type::default_instance())); + } + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterMessageExtension(&ExtendeeT::default_instance(), + number, type, false, is_packed, + &Type::default_instance()); + } +}; + +// forward declaration +class RepeatedMessageGenericTypeTraits; + +template +class RepeatedMessageTypeTraits { + public: + typedef const Type& ConstType; + typedef Type* MutableType; + typedef RepeatedMessageTypeTraits Repeated; + + typedef RepeatedPtrField RepeatedFieldType; + + static inline ConstType Get(int number, const ExtensionSet& set, int index) { + return static_cast(set.GetRepeatedMessage(number, index)); + } + static inline MutableType Mutable(int number, int index, ExtensionSet* set) { + return static_cast(set->MutableRepeatedMessage(number, index)); + } + static inline MutableType Add(int number, FieldType field_type, + ExtensionSet* set) { + return static_cast( + set->AddMessage(number, field_type, Type::default_instance(), NULL)); + } + static inline const RepeatedPtrField& GetRepeated( + int number, const ExtensionSet& set) { + // See notes above in RepeatedEnumTypeTraits::GetRepeated(): same + // casting hack applies here, because a RepeatedPtrField + // cannot naturally become a RepeatedPtrType even though Type is + // presumably a message. google::protobuf::Message goes through similar contortions + // with a reinterpret_cast<>. + return *reinterpret_cast*>( + set.GetRawRepeatedField(number, GetDefaultRepeatedField())); + } + static inline RepeatedPtrField* MutableRepeated(int number, + FieldType field_type, + bool is_packed, + ExtensionSet* set) { + return reinterpret_cast*>( + set->MutableRawRepeatedField(number, field_type, is_packed, NULL)); + } + + static const RepeatedFieldType* GetDefaultRepeatedField(); + template + static void Register(int number, FieldType type, bool is_packed) { + ExtensionSet::RegisterMessageExtension(&ExtendeeT::default_instance(), + number, type, true, is_packed, + &Type::default_instance()); + } +}; + +template +inline const typename RepeatedMessageTypeTraits::RepeatedFieldType* +RepeatedMessageTypeTraits::GetDefaultRepeatedField() { + static auto instance = OnShutdownDelete(new RepeatedFieldType); + return instance; +} + +// ------------------------------------------------------------------- +// ExtensionIdentifier + +// This is the type of actual extension objects. E.g. if you have: +// extends Foo with optional int32 bar = 1234; +// then "bar" will be defined in C++ as: +// ExtensionIdentifier, 5, false> bar(1234); +// +// Note that we could, in theory, supply the field number as a template +// parameter, and thus make an instance of ExtensionIdentifier have no +// actual contents. However, if we did that, then using an extension +// identifier would not necessarily cause the compiler to output any sort +// of reference to any symbol defined in the extension's .pb.o file. Some +// linkers will actually drop object files that are not explicitly referenced, +// but that would be bad because it would cause this extension to not be +// registered at static initialization, and therefore using it would crash. + +template +class ExtensionIdentifier { + public: + typedef TypeTraitsType TypeTraits; + typedef ExtendeeType Extendee; + + ExtensionIdentifier(int number, typename TypeTraits::ConstType default_value) + : number_(number), default_value_(default_value) { + Register(number); + } + inline int number() const { return number_; } + typename TypeTraits::ConstType default_value() const { + return default_value_; + } + + static void Register(int number) { + TypeTraits::template Register(number, field_type, is_packed); + } + + private: + const int number_; + typename TypeTraits::ConstType default_value_; +}; + +// ------------------------------------------------------------------- +// Generated accessors + +// This macro should be expanded in the context of a generated type which +// has extensions. +// +// We use "_proto_TypeTraits" as a type name below because "TypeTraits" +// causes problems if the class has a nested message or enum type with that +// name and "_TypeTraits" is technically reserved for the C++ library since +// it starts with an underscore followed by a capital letter. +// +// For similar reason, we use "_field_type" and "_is_packed" as parameter names +// below, so that "field_type" and "is_packed" can be used as field names. +#define GOOGLE_PROTOBUF_EXTENSION_ACCESSORS(CLASSNAME) \ + /* Has, Size, Clear */ \ + template \ + inline bool HasExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) const { \ + return _extensions_.Has(id.number()); \ + } \ + \ + template \ + inline void ClearExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) { \ + _extensions_.ClearExtension(id.number()); \ + } \ + \ + template \ + inline int ExtensionSize( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) const { \ + return _extensions_.ExtensionSize(id.number()); \ + } \ + \ + /* Singular accessors */ \ + template \ + inline typename _proto_TypeTraits::Singular::ConstType GetExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) const { \ + return _proto_TypeTraits::Get(id.number(), _extensions_, \ + id.default_value()); \ + } \ + \ + template \ + inline typename _proto_TypeTraits::Singular::MutableType MutableExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) { \ + return _proto_TypeTraits::Mutable(id.number(), _field_type, \ + &_extensions_); \ + } \ + \ + template \ + inline void SetExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id, \ + typename _proto_TypeTraits::Singular::ConstType value) { \ + _proto_TypeTraits::Set(id.number(), _field_type, value, &_extensions_); \ + } \ + \ + template \ + inline void SetAllocatedExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id, \ + typename _proto_TypeTraits::Singular::MutableType value) { \ + _proto_TypeTraits::SetAllocated(id.number(), _field_type, value, \ + &_extensions_); \ + } \ + template \ + inline void UnsafeArenaSetAllocatedExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id, \ + typename _proto_TypeTraits::Singular::MutableType value) { \ + _proto_TypeTraits::UnsafeArenaSetAllocated(id.number(), _field_type, \ + value, &_extensions_); \ + } \ + template \ + inline typename _proto_TypeTraits::Singular::MutableType ReleaseExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) { \ + return _proto_TypeTraits::Release(id.number(), _field_type, \ + &_extensions_); \ + } \ + template \ + inline typename _proto_TypeTraits::Singular::MutableType \ + UnsafeArenaReleaseExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) { \ + return _proto_TypeTraits::UnsafeArenaRelease(id.number(), _field_type, \ + &_extensions_); \ + } \ + \ + /* Repeated accessors */ \ + template \ + inline typename _proto_TypeTraits::Repeated::ConstType GetExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id, \ + int index) const { \ + return _proto_TypeTraits::Get(id.number(), _extensions_, index); \ + } \ + \ + template \ + inline typename _proto_TypeTraits::Repeated::MutableType MutableExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id, \ + int index) { \ + return _proto_TypeTraits::Mutable(id.number(), index, &_extensions_); \ + } \ + \ + template \ + inline void SetExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id, \ + int index, typename _proto_TypeTraits::Repeated::ConstType value) { \ + _proto_TypeTraits::Set(id.number(), index, value, &_extensions_); \ + } \ + \ + template \ + inline typename _proto_TypeTraits::Repeated::MutableType AddExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) { \ + return _proto_TypeTraits::Add(id.number(), _field_type, &_extensions_); \ + } \ + \ + template \ + inline void AddExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id, \ + typename _proto_TypeTraits::Repeated::ConstType value) { \ + _proto_TypeTraits::Add(id.number(), _field_type, _is_packed, value, \ + &_extensions_); \ + } \ + \ + template \ + inline const typename _proto_TypeTraits::Repeated::RepeatedFieldType& \ + GetRepeatedExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) const { \ + return _proto_TypeTraits::GetRepeated(id.number(), _extensions_); \ + } \ + \ + template \ + inline typename _proto_TypeTraits::Repeated::RepeatedFieldType* \ + MutableRepeatedExtension( \ + const ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< \ + CLASSNAME, _proto_TypeTraits, _field_type, _is_packed>& id) { \ + return _proto_TypeTraits::MutableRepeated(id.number(), _field_type, \ + _is_packed, &_extensions_); \ + } + +} // namespace internal + +// Call this function to ensure that this extensions's reflection is linked into +// the binary: +// +// google::protobuf::LinkExtensionReflection(Foo::my_extension); +// +// This will ensure that the following lookup will succeed: +// +// DescriptorPool::generated_pool()->FindExtensionByName("Foo.my_extension"); +// +// This is often relevant for parsing extensions in text mode. +// +// As a side-effect, it will also guarantee that anything else from the same +// .proto file will also be available for lookup in the generated pool. +// +// This function does not actually register the extension, so it does not need +// to be called before the lookup. However it does need to occur in a function +// that cannot be stripped from the binary (ie. it must be reachable from main). +// +// Best practice is to call this function as close as possible to where the +// reflection is actually needed. This function is very cheap to call, so you +// should not need to worry about its runtime overhead except in tight loops (on +// x86-64 it compiles into two "mov" instructions). +template +void LinkExtensionReflection( + const google::protobuf::internal::ExtensionIdentifier< + ExtendeeType, TypeTraitsType, field_type, is_packed>& extension) { + internal::StrongReference(extension); +} + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_EXTENSION_SET_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set_inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..c8996c1c4e5022c0fe2a0e7fedbdd7241a8716b6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set_inl.h @@ -0,0 +1,281 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_EXTENSION_SET_INL_H__ +#define GOOGLE_PROTOBUF_EXTENSION_SET_INL_H__ + +#include +#include +#include + +namespace google { +namespace protobuf { +namespace internal { + +template +const char* ExtensionSet::ParseFieldWithExtensionInfo( + int number, bool was_packed_on_wire, const ExtensionInfo& extension, + InternalMetadata* metadata, const char* ptr, internal::ParseContext* ctx) { + if (was_packed_on_wire) { + switch (extension.type) { +#define HANDLE_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: \ + return internal::Packed##CPP_CAMELCASE##Parser( \ + MutableRawRepeatedField(number, extension.type, extension.is_packed, \ + extension.descriptor), \ + ptr, ctx); + HANDLE_TYPE(INT32, Int32); + HANDLE_TYPE(INT64, Int64); + HANDLE_TYPE(UINT32, UInt32); + HANDLE_TYPE(UINT64, UInt64); + HANDLE_TYPE(SINT32, SInt32); + HANDLE_TYPE(SINT64, SInt64); + HANDLE_TYPE(FIXED32, Fixed32); + HANDLE_TYPE(FIXED64, Fixed64); + HANDLE_TYPE(SFIXED32, SFixed32); + HANDLE_TYPE(SFIXED64, SFixed64); + HANDLE_TYPE(FLOAT, Float); + HANDLE_TYPE(DOUBLE, Double); + HANDLE_TYPE(BOOL, Bool); +#undef HANDLE_TYPE + + case WireFormatLite::TYPE_ENUM: + return internal::PackedEnumParserArg( + MutableRawRepeatedField(number, extension.type, extension.is_packed, + extension.descriptor), + ptr, ctx, extension.enum_validity_check.func, + extension.enum_validity_check.arg, metadata, number); + case WireFormatLite::TYPE_STRING: + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_GROUP: + case WireFormatLite::TYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed."; + break; + } + } else { + switch (extension.type) { +#define HANDLE_VARINT_TYPE(UPPERCASE, CPP_CAMELCASE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 value; \ + ptr = VarintParse(ptr, &value); \ + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_VARINT_TYPE(INT32, Int32); + HANDLE_VARINT_TYPE(INT64, Int64); + HANDLE_VARINT_TYPE(UINT32, UInt32); + HANDLE_VARINT_TYPE(UINT64, UInt64); + HANDLE_VARINT_TYPE(BOOL, Bool); +#undef HANDLE_VARINT_TYPE +#define HANDLE_SVARINT_TYPE(UPPERCASE, CPP_CAMELCASE, SIZE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + uint64 val; \ + ptr = VarintParse(ptr, &val); \ + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); \ + auto value = WireFormatLite::ZigZagDecode##SIZE(val); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_SVARINT_TYPE(SINT32, Int32, 32); + HANDLE_SVARINT_TYPE(SINT64, Int64, 64); +#undef HANDLE_SVARINT_TYPE +#define HANDLE_FIXED_TYPE(UPPERCASE, CPP_CAMELCASE, CPPTYPE) \ + case WireFormatLite::TYPE_##UPPERCASE: { \ + auto value = UnalignedLoad(ptr); \ + ptr += sizeof(CPPTYPE); \ + if (extension.is_repeated) { \ + Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \ + extension.is_packed, value, extension.descriptor); \ + } else { \ + Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \ + extension.descriptor); \ + } \ + } break + + HANDLE_FIXED_TYPE(FIXED32, UInt32, uint32); + HANDLE_FIXED_TYPE(FIXED64, UInt64, uint64); + HANDLE_FIXED_TYPE(SFIXED32, Int32, int32); + HANDLE_FIXED_TYPE(SFIXED64, Int64, int64); + HANDLE_FIXED_TYPE(FLOAT, Float, float); + HANDLE_FIXED_TYPE(DOUBLE, Double, double); +#undef HANDLE_FIXED_TYPE + + case WireFormatLite::TYPE_ENUM: { + uint64 val; + ptr = VarintParse(ptr, &val); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + int value = val; + + if (!extension.enum_validity_check.func( + extension.enum_validity_check.arg, value)) { + WriteVarint(number, val, metadata->mutable_unknown_fields()); + } else if (extension.is_repeated) { + AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value, + extension.descriptor); + } else { + SetEnum(number, WireFormatLite::TYPE_ENUM, value, + extension.descriptor); + } + break; + } + + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_STRING: { + std::string* value = + extension.is_repeated + ? AddString(number, WireFormatLite::TYPE_STRING, + extension.descriptor) + : MutableString(number, WireFormatLite::TYPE_STRING, + extension.descriptor); + int size = ReadSize(&ptr); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + return ctx->ReadString(ptr, size, value); + } + + case WireFormatLite::TYPE_GROUP: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_info.prototype, + extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_GROUP, + *extension.message_info.prototype, + extension.descriptor); + uint32 tag = (number << 3) + WireFormatLite::WIRETYPE_START_GROUP; + return ctx->ParseGroup(value, ptr, tag); + } + + case WireFormatLite::TYPE_MESSAGE: { + MessageLite* value = + extension.is_repeated + ? AddMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_info.prototype, + extension.descriptor) + : MutableMessage(number, WireFormatLite::TYPE_MESSAGE, + *extension.message_info.prototype, + extension.descriptor); + return ctx->ParseMessage(value, ptr); + } + } + } + return ptr; +} + +template +const char* ExtensionSet::ParseMessageSetItemTmpl( + const char* ptr, const Msg* containing_type, + internal::InternalMetadata* metadata, internal::ParseContext* ctx) { + std::string payload; + uint32 type_id = 0; + bool payload_read = false; + while (!ctx->Done(&ptr)) { + uint32 tag = static_cast(*ptr++); + if (tag == WireFormatLite::kMessageSetTypeIdTag) { + uint64 tmp; + ptr = ParseBigVarint(ptr, &tmp); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + type_id = tmp; + if (payload_read) { + ExtensionInfo extension; + bool was_packed_on_wire; + if (!FindExtension(2, type_id, containing_type, ctx, &extension, + &was_packed_on_wire)) { + WriteLengthDelimited(type_id, payload, + metadata->mutable_unknown_fields()); + } else { + MessageLite* value = + extension.is_repeated + ? AddMessage(type_id, WireFormatLite::TYPE_MESSAGE, + *extension.message_info.prototype, + extension.descriptor) + : MutableMessage(type_id, WireFormatLite::TYPE_MESSAGE, + *extension.message_info.prototype, + extension.descriptor); + + const char* p; + // We can't use regular parse from string as we have to track + // proper recursion depth and descriptor pools. + ParseContext tmp_ctx(ctx->depth(), false, &p, payload); + tmp_ctx.data().pool = ctx->data().pool; + tmp_ctx.data().factory = ctx->data().factory; + GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && + tmp_ctx.EndedAtLimit()); + } + type_id = 0; + } + } else if (tag == WireFormatLite::kMessageSetMessageTag) { + if (type_id != 0) { + ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, + containing_type, metadata, ctx); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); + type_id = 0; + } else { + int32 size = ReadSize(&ptr); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ptr = ctx->ReadString(ptr, size, &payload); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + payload_read = true; + } + } else { + ptr = ReadTag(ptr - 1, &tag); + if (tag == 0 || (tag & 7) == 4) { + ctx->SetLastTag(tag); + return ptr; + } + ptr = ParseField(tag, ptr, containing_type, metadata, ctx); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + } + } + return ptr; +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_EXTENSION_SET_INL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/field_mask.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/field_mask.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..d8d57d823c293171522dfef2099cf1be0b51b261 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/field_mask.pb.h @@ -0,0 +1,320 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/field_mask.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ffield_5fmask_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ffield_5fmask_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2ffield_5fmask_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2ffield_5fmask_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2ffield_5fmask_2eproto; +PROTOBUF_NAMESPACE_OPEN +class FieldMask; +class FieldMaskDefaultTypeInternal; +PROTOBUF_EXPORT extern FieldMaskDefaultTypeInternal _FieldMask_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::FieldMask* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +// =================================================================== + +class PROTOBUF_EXPORT FieldMask PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.FieldMask) */ { + public: + inline FieldMask() : FieldMask(nullptr) {} + virtual ~FieldMask(); + + FieldMask(const FieldMask& from); + FieldMask(FieldMask&& from) noexcept + : FieldMask() { + *this = ::std::move(from); + } + + inline FieldMask& operator=(const FieldMask& from) { + CopyFrom(from); + return *this; + } + inline FieldMask& operator=(FieldMask&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const FieldMask& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const FieldMask* internal_default_instance() { + return reinterpret_cast( + &_FieldMask_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(FieldMask& a, FieldMask& b) { + a.Swap(&b); + } + inline void Swap(FieldMask* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(FieldMask* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline FieldMask* New() const final { + return CreateMaybeMessage(nullptr); + } + + FieldMask* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const FieldMask& from); + void MergeFrom(const FieldMask& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(FieldMask* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.FieldMask"; + } + protected: + explicit FieldMask(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2ffield_5fmask_2eproto); + return ::descriptor_table_google_2fprotobuf_2ffield_5fmask_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kPathsFieldNumber = 1, + }; + // repeated string paths = 1; + int paths_size() const; + private: + int _internal_paths_size() const; + public: + void clear_paths(); + const std::string& paths(int index) const; + std::string* mutable_paths(int index); + void set_paths(int index, const std::string& value); + void set_paths(int index, std::string&& value); + void set_paths(int index, const char* value); + void set_paths(int index, const char* value, size_t size); + std::string* add_paths(); + void add_paths(const std::string& value); + void add_paths(std::string&& value); + void add_paths(const char* value); + void add_paths(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& paths() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_paths(); + private: + const std::string& _internal_paths(int index) const; + std::string* _internal_add_paths(); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.FieldMask) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField paths_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2ffield_5fmask_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// FieldMask + +// repeated string paths = 1; +inline int FieldMask::_internal_paths_size() const { + return paths_.size(); +} +inline int FieldMask::paths_size() const { + return _internal_paths_size(); +} +inline void FieldMask::clear_paths() { + paths_.Clear(); +} +inline std::string* FieldMask::add_paths() { + // @@protoc_insertion_point(field_add_mutable:google.protobuf.FieldMask.paths) + return _internal_add_paths(); +} +inline const std::string& FieldMask::_internal_paths(int index) const { + return paths_.Get(index); +} +inline const std::string& FieldMask::paths(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.FieldMask.paths) + return _internal_paths(index); +} +inline std::string* FieldMask::mutable_paths(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.FieldMask.paths) + return paths_.Mutable(index); +} +inline void FieldMask::set_paths(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:google.protobuf.FieldMask.paths) + paths_.Mutable(index)->assign(value); +} +inline void FieldMask::set_paths(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.FieldMask.paths) + paths_.Mutable(index)->assign(std::move(value)); +} +inline void FieldMask::set_paths(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + paths_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:google.protobuf.FieldMask.paths) +} +inline void FieldMask::set_paths(int index, const char* value, size_t size) { + paths_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.FieldMask.paths) +} +inline std::string* FieldMask::_internal_add_paths() { + return paths_.Add(); +} +inline void FieldMask::add_paths(const std::string& value) { + paths_.Add()->assign(value); + // @@protoc_insertion_point(field_add:google.protobuf.FieldMask.paths) +} +inline void FieldMask::add_paths(std::string&& value) { + paths_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:google.protobuf.FieldMask.paths) +} +inline void FieldMask::add_paths(const char* value) { + GOOGLE_DCHECK(value != nullptr); + paths_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:google.protobuf.FieldMask.paths) +} +inline void FieldMask::add_paths(const char* value, size_t size) { + paths_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:google.protobuf.FieldMask.paths) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +FieldMask::paths() const { + // @@protoc_insertion_point(field_list:google.protobuf.FieldMask.paths) + return paths_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +FieldMask::mutable_paths() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.FieldMask.paths) + return &paths_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ffield_5fmask_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_enum_reflection.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_enum_reflection.h new file mode 100644 index 0000000000000000000000000000000000000000..64257d58ffef9d1094a797b0ec9ce315315ee42f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_enum_reflection.h @@ -0,0 +1,103 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jasonh@google.com (Jason Hsueh) +// +// This header is logically internal, but is made public because it is used +// from protocol-compiler-generated code, which may reside in other components. +// It provides reflection support for generated enums, and is included in +// generated .pb.h files and should have minimal dependencies. The methods are +// implemented in generated_message_reflection.cc. + +#ifndef GOOGLE_PROTOBUF_GENERATED_ENUM_REFLECTION_H__ +#define GOOGLE_PROTOBUF_GENERATED_ENUM_REFLECTION_H__ + +#include + +#include +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +namespace google { +namespace protobuf { +class EnumDescriptor; +} // namespace protobuf +} // namespace google + +namespace google { +namespace protobuf { + +// Returns the EnumDescriptor for enum type E, which must be a +// proto-declared enum type. Code generated by the protocol compiler +// will include specializations of this template for each enum type declared. +template +const EnumDescriptor* GetEnumDescriptor(); + +namespace internal { + +// Helper for EnumType_Parse functions: try to parse the string 'name' as +// an enum name of the given type, returning true and filling in value on +// success, or returning false and leaving value unchanged on failure. +PROTOBUF_EXPORT bool ParseNamedEnum(const EnumDescriptor* descriptor, + ConstStringParam name, int* value); + +template +bool ParseNamedEnum(const EnumDescriptor* descriptor, ConstStringParam name, + EnumType* value) { + int tmp; + if (!ParseNamedEnum(descriptor, name, &tmp)) return false; + *value = static_cast(tmp); + return true; +} + +// Just a wrapper around printing the name of a value. The main point of this +// function is not to be inlined, so that you can do this without including +// descriptor.h. +PROTOBUF_EXPORT const std::string& NameOfEnum(const EnumDescriptor* descriptor, + int value); + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_GENERATED_ENUM_REFLECTION_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_enum_util.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_enum_util.h new file mode 100644 index 0000000000000000000000000000000000000000..45f5083336bebfda4e5dd65dcb1e68c9e6196daf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_enum_util.h @@ -0,0 +1,88 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_GENERATED_ENUM_UTIL_H__ +#define GOOGLE_PROTOBUF_GENERATED_ENUM_UTIL_H__ + +#include + +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +// This type trait can be used to cause templates to only match proto2 enum +// types. +template +struct is_proto_enum : ::std::false_type {}; + +namespace internal { + +// The table entry format for storing enum name-to-value mapping used with lite +// protos. This struct and the following related functions should only be used +// by protobuf generated code. +struct EnumEntry { + StringPiece name; + int value; +}; + +// Looks up a numeric enum value given the string name. +PROTOBUF_EXPORT bool LookUpEnumValue(const EnumEntry* enums, size_t size, + StringPiece name, int* value); + +// Looks up an enum name given the numeric value. +PROTOBUF_EXPORT int LookUpEnumName(const EnumEntry* enums, + const int* sorted_indices, size_t size, + int value); + +// Initializes the list of enum names in std::string form. +PROTOBUF_EXPORT bool InitializeEnumStrings( + const EnumEntry* enums, const int* sorted_indices, size_t size, + internal::ExplicitlyConstructed* enum_strings); + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_GENERATED_ENUM_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_reflection.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_reflection.h new file mode 100644 index 0000000000000000000000000000000000000000..0ce150f8cf53e5da729b1ae00ea22b2f867302ae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_reflection.h @@ -0,0 +1,330 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This header is logically internal, but is made public because it is used +// from protocol-compiler-generated code, which may reside in other components. + +#ifndef GOOGLE_PROTOBUF_GENERATED_MESSAGE_REFLECTION_H__ +#define GOOGLE_PROTOBUF_GENERATED_MESSAGE_REFLECTION_H__ + +#include +#include +#include +#include +// TODO(jasonh): Remove this once the compiler change to directly include this +// is released to components. +#include +#include +#include +#include +#include + + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +class DescriptorPool; +class MapKey; +class MapValueRef; +class MessageLayoutInspector; +class Message; +struct Metadata; +} // namespace protobuf +} // namespace google + +namespace google { +namespace protobuf { +namespace internal { +class DefaultEmptyOneof; +// Defined in other files. +class ExtensionSet; // extension_set.h +class WeakFieldMap; // weak_field_map.h + +// This struct describes the internal layout of the message, hence this is +// used to act on the message reflectively. +// default_instance: The default instance of the message. This is only +// used to obtain pointers to default instances of embedded +// messages, which GetMessage() will return if the particular +// sub-message has not been initialized yet. (Thus, all +// embedded message fields *must* have non-null pointers +// in the default instance.) +// offsets: An array of ints giving the byte offsets. +// For each oneof or weak field, the offset is relative to the +// default_instance. These can be computed at compile time +// using the +// PROTO2_GENERATED_DEFAULT_ONEOF_FIELD_OFFSET() +// macro. For each none oneof field, the offset is related to +// the start of the message object. These can be computed at +// compile time using the +// PROTO2_GENERATED_MESSAGE_FIELD_OFFSET() macro. +// Besides offsets for all fields, this array also contains +// offsets for oneof unions. The offset of the i-th oneof union +// is offsets[descriptor->field_count() + i]. +// has_bit_indices: Mapping from field indexes to their index in the has +// bit array. +// has_bits_offset: Offset in the message of an array of uint32s of size +// descriptor->field_count()/32, rounded up. This is a +// bitfield where each bit indicates whether or not the +// corresponding field of the message has been initialized. +// The bit for field index i is obtained by the expression: +// has_bits[i / 32] & (1 << (i % 32)) +// unknown_fields_offset: Offset in the message of the UnknownFieldSet for +// the message. +// extensions_offset: Offset in the message of the ExtensionSet for the +// message, or -1 if the message type has no extension +// ranges. +// oneof_case_offset: Offset in the message of an array of uint32s of +// size descriptor->oneof_decl_count(). Each uint32 +// indicates what field is set for each oneof. +// object_size: The size of a message object of this type, as measured +// by sizeof(). +// arena_offset: If a message doesn't have a unknown_field_set that stores +// the arena, it must have a direct pointer to the arena. +// weak_field_map_offset: If the message proto has weak fields, this is the +// offset of _weak_field_map_ in the generated proto. Otherwise +// -1. +struct ReflectionSchema { + public: + // Size of a google::protobuf::Message object of this type. + uint32 GetObjectSize() const { return static_cast(object_size_); } + + bool InRealOneof(const FieldDescriptor* field) const { + return field->containing_oneof() && + !field->containing_oneof()->is_synthetic(); + } + + // Offset of a non-oneof field. Getting a field offset is slightly more + // efficient when we know statically that it is not a oneof field. + uint32 GetFieldOffsetNonOneof(const FieldDescriptor* field) const { + GOOGLE_DCHECK(!InRealOneof(field)); + return OffsetValue(offsets_[field->index()], field->type()); + } + + // Offset of any field. + uint32 GetFieldOffset(const FieldDescriptor* field) const { + if (InRealOneof(field)) { + size_t offset = + static_cast(field->containing_type()->field_count() + + field->containing_oneof()->index()); + return OffsetValue(offsets_[offset], field->type()); + } else { + return GetFieldOffsetNonOneof(field); + } + } + + bool IsFieldInlined(const FieldDescriptor* field) const { + if (InRealOneof(field)) { + size_t offset = + static_cast(field->containing_type()->field_count() + + field->containing_oneof()->index()); + return Inlined(offsets_[offset], field->type()); + } else { + return Inlined(offsets_[field->index()], field->type()); + } + } + + uint32 GetOneofCaseOffset(const OneofDescriptor* oneof_descriptor) const { + return static_cast(oneof_case_offset_) + + static_cast(static_cast(oneof_descriptor->index()) * + sizeof(uint32)); + } + + bool HasHasbits() const { return has_bits_offset_ != -1; } + + // Bit index within the bit array of hasbits. Bit order is low-to-high. + uint32 HasBitIndex(const FieldDescriptor* field) const { + if (has_bits_offset_ == -1) return static_cast(-1); + GOOGLE_DCHECK(HasHasbits()); + return has_bit_indices_[field->index()]; + } + + // Byte offset of the hasbits array. + uint32 HasBitsOffset() const { + GOOGLE_DCHECK(HasHasbits()); + return static_cast(has_bits_offset_); + } + + // The offset of the InternalMetadataWithArena member. + // For Lite this will actually be an InternalMetadataWithArenaLite. + // The schema doesn't contain enough information to distinguish between + // these two cases. + uint32 GetMetadataOffset() const { + return static_cast(metadata_offset_); + } + + // Whether this message has an ExtensionSet. + bool HasExtensionSet() const { return extensions_offset_ != -1; } + + // The offset of the ExtensionSet in this message. + uint32 GetExtensionSetOffset() const { + GOOGLE_DCHECK(HasExtensionSet()); + return static_cast(extensions_offset_); + } + + // The off set of WeakFieldMap when the message contains weak fields. + // The default is 0 for now. + int GetWeakFieldMapOffset() const { return weak_field_map_offset_; } + + bool IsDefaultInstance(const Message& message) const { + return &message == default_instance_; + } + + // Returns a pointer to the default value for this field. The size and type + // of the underlying data depends on the field's type. + const void* GetFieldDefault(const FieldDescriptor* field) const { + return reinterpret_cast(default_instance_) + + OffsetValue(offsets_[field->index()], field->type()); + } + + bool IsFieldStripped(const FieldDescriptor* field) const { + return false; + } + + bool IsMessageStripped(const Descriptor* descriptor) const { + return false; + } + + + bool HasWeakFields() const { return weak_field_map_offset_ > 0; } + + // These members are intended to be private, but we cannot actually make them + // private because this prevents us from using aggregate initialization of + // them, ie. + // + // ReflectionSchema schema = {a, b, c, d, e, ...}; + // private: + const Message* default_instance_; + const uint32* offsets_; + const uint32* has_bit_indices_; + int has_bits_offset_; + int metadata_offset_; + int extensions_offset_; + int oneof_case_offset_; + int object_size_; + int weak_field_map_offset_; + + // We tag offset values to provide additional data about fields (such as + // inlined). + static uint32 OffsetValue(uint32 v, FieldDescriptor::Type type) { + if (type == FieldDescriptor::TYPE_STRING || + type == FieldDescriptor::TYPE_BYTES) { + return v & ~1u; + } else { + return v; + } + } + + static bool Inlined(uint32 v, FieldDescriptor::Type type) { + if (type == FieldDescriptor::TYPE_STRING || + type == FieldDescriptor::TYPE_BYTES) { + return v & 1u; + } else { + // Non string/byte fields are not inlined. + return false; + } + } +}; + +// Structs that the code generator emits directly to describe a message. +// These should never used directly except to build a ReflectionSchema +// object. +// +// EXPERIMENTAL: these are changing rapidly, and may completely disappear +// or merge with ReflectionSchema. +struct MigrationSchema { + int32 offsets_index; + int32 has_bit_indices_index; + int object_size; +}; + +struct SCCInfoBase; + +struct PROTOBUF_EXPORT DescriptorTable { + mutable bool is_initialized; + bool is_eager; + const char* descriptor; + const char* filename; + int size; // of serialized descriptor + once_flag* once; + SCCInfoBase* const* init_default_instances; + const DescriptorTable* const* deps; + int num_sccs; + int num_deps; + const MigrationSchema* schemas; + const Message* const* default_instances; + const uint32* offsets; + // update the following descriptor arrays. + Metadata* file_level_metadata; + int num_messages; + const EnumDescriptor** file_level_enum_descriptors; + const ServiceDescriptor** file_level_service_descriptors; +}; + +// AssignDescriptors() pulls the compiled FileDescriptor from the DescriptorPool +// and uses it to populate all of the global variables which store pointers to +// the descriptor objects. It also constructs the reflection objects. It is +// called the first time anyone calls descriptor() or GetReflection() on one of +// the types defined in the file. AssignDescriptors() is thread-safe. +void PROTOBUF_EXPORT AssignDescriptors(const DescriptorTable* table, + bool eager = false); + +// AddDescriptors() is a file-level procedure which adds the encoded +// FileDescriptorProto for this .proto file to the global DescriptorPool for +// generated files (DescriptorPool::generated_pool()). It ordinarily runs at +// static initialization time, but is not used at all in LITE_RUNTIME mode. +// AddDescriptors() is *not* thread-safe. +void PROTOBUF_EXPORT AddDescriptors(const DescriptorTable* table); + +// These cannot be in lite so we put them in the reflection. +PROTOBUF_EXPORT void UnknownFieldSetSerializer(const uint8* base, uint32 offset, + uint32 tag, uint32 has_offset, + io::CodedOutputStream* output); + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_GENERATED_MESSAGE_REFLECTION_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_table_driven.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_table_driven.h new file mode 100644 index 0000000000000000000000000000000000000000..7165fd31b5927217ed2628ca6044a2ffc41e5274 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_table_driven.h @@ -0,0 +1,344 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_GENERATED_MESSAGE_TABLE_DRIVEN_H__ +#define GOOGLE_PROTOBUF_GENERATED_MESSAGE_TABLE_DRIVEN_H__ + +#include +#include +#include +#include +#include + +// We require C++11 and Clang to use constexpr for variables, as GCC 4.8 +// requires constexpr to be consistent between declarations of variables +// unnecessarily (see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=58541). +// VS 2017 Update 3 also supports this usage of constexpr. +#if defined(__clang__) || (defined(_MSC_VER) && _MSC_VER >= 1911) +#define PROTOBUF_CONSTEXPR_VAR constexpr +#else // !__clang__ +#define PROTOBUF_CONSTEXPR_VAR +#endif // !_clang + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +namespace google { +namespace protobuf { +namespace internal { + +// Processing-type masks. +static constexpr const unsigned char kOneofMask = 0x40; +static constexpr const unsigned char kRepeatedMask = 0x20; +// Mask for the raw type: either a WireFormatLite::FieldType or one of the +// ProcessingTypes below, without the oneof or repeated flag. +static constexpr const unsigned char kTypeMask = 0x1f; + +// Wire type masks. +static constexpr const unsigned char kNotPackedMask = 0x10; +static constexpr const unsigned char kInvalidMask = 0x20; + +enum ProcessingTypes { + TYPE_STRING_CORD = 19, + TYPE_STRING_STRING_PIECE = 20, + TYPE_BYTES_CORD = 21, + TYPE_BYTES_STRING_PIECE = 22, + TYPE_STRING_INLINED = 23, + TYPE_BYTES_INLINED = 24, + TYPE_MAP = 25, +}; + +static_assert(TYPE_MAP < kRepeatedMask, "Invalid enum"); + +struct PROTOBUF_EXPORT FieldMetadata { + uint32 offset; // offset of this field in the struct + uint32 tag; // field * 8 + wire_type + // byte offset * 8 + bit_offset; + // if the high bit is set then this is the byte offset of the oneof_case + // for this field. + uint32 has_offset; + uint32 type; // the type of this field. + const void* ptr; // auxiliary data + + // From the serializer point of view each fundamental type can occur in + // 4 different ways. For simplicity we treat all combinations as a cartesion + // product although not all combinations are allowed. + enum FieldTypeClass { + kPresence, + kNoPresence, + kRepeated, + kPacked, + kOneOf, + kNumTypeClasses // must be last enum + }; + // C++ protobuf has 20 fundamental types, were we added Cord and StringPiece + // and also distinquish the same types if they have different wire format. + enum { + kCordType = 19, + kStringPieceType = 20, + kInlinedType = 21, + kNumTypes = 21, + kSpecial = kNumTypes * kNumTypeClasses, + }; + + static int CalculateType(int fundamental_type, FieldTypeClass type_class); +}; + +// TODO(ckennelly): Add a static assertion to ensure that these masks do not +// conflict with wiretypes. + +// ParseTableField is kept small to help simplify instructions for computing +// offsets, as we will always need this information to parse a field. +// Additional data, needed for some types, is stored in +// AuxiliaryParseTableField. +struct ParseTableField { + uint32 offset; + // The presence_index ordinarily represents a has_bit index, but for fields + // inside a oneof it represents the index in _oneof_case_. + uint32 presence_index; + unsigned char normal_wiretype; + unsigned char packed_wiretype; + + // processing_type is given by: + // (FieldDescriptor->type() << 1) | FieldDescriptor->is_packed() + unsigned char processing_type; + + unsigned char tag_size; +}; + +struct ParseTable; + +union AuxiliaryParseTableField { + typedef bool (*EnumValidator)(int); + + // Enums + struct enum_aux { + EnumValidator validator; + }; + enum_aux enums; + // Group, messages + struct message_aux { + // ExplicitlyInitialized -> T requires a reinterpret_cast, which prevents + // the tables from being constructed as a constexpr. We use void to avoid + // the cast. + const void* default_message_void; + const MessageLite* default_message() const { + return static_cast(default_message_void); + } + }; + message_aux messages; + // Strings + struct string_aux { + const void* default_ptr; + const char* field_name; + }; + string_aux strings; + + struct map_aux { + bool (*parse_map)(io::CodedInputStream*, void*); + }; + map_aux maps; + + AuxiliaryParseTableField() = default; + constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::enum_aux e) + : enums(e) {} + constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::message_aux m) + : messages(m) {} + constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::string_aux s) + : strings(s) {} + constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::map_aux m) + : maps(m) {} +}; + +struct ParseTable { + const ParseTableField* fields; + const AuxiliaryParseTableField* aux; + int max_field_number; + // TODO(ckennelly): Do something with this padding. + + // TODO(ckennelly): Vet these for sign extension. + int64 has_bits_offset; + int64 oneof_case_offset; + int64 extension_offset; + int64 arena_offset; + + // ExplicitlyInitialized -> T requires a reinterpret_cast, which prevents + // the tables from being constructed as a constexpr. We use void to avoid + // the cast. + const void* default_instance_void; + const MessageLite* default_instance() const { + return static_cast(default_instance_void); + } + + bool unknown_field_set; +}; + +static_assert(sizeof(ParseTableField) <= 16, "ParseTableField is too large"); +// The tables must be composed of POD components to ensure link-time +// initialization. +static_assert(std::is_pod::value, ""); +static_assert(std::is_pod::value, ""); +static_assert(std::is_pod::value, ""); +static_assert(std::is_pod::value, ""); +static_assert(std::is_pod::value, ""); +static_assert(std::is_pod::value, ""); + +// TODO(ckennelly): Consolidate these implementations into a single one, using +// dynamic dispatch to the appropriate unknown field handler. +bool MergePartialFromCodedStream(MessageLite* msg, const ParseTable& table, + io::CodedInputStream* input); +bool MergePartialFromCodedStreamLite(MessageLite* msg, const ParseTable& table, + io::CodedInputStream* input); + +template +bool ParseMap(io::CodedInputStream* input, void* map_field) { + typedef typename MapEntryToMapField::MapFieldType MapFieldType; + typedef Map + MapType; + typedef typename Entry::template Parser ParserType; + + ParserType parser(static_cast(map_field)); + return WireFormatLite::ReadMessageNoVirtual(input, &parser); +} + +struct SerializationTable { + int num_fields; + const FieldMetadata* field_table; +}; + +PROTOBUF_EXPORT void SerializeInternal(const uint8* base, + const FieldMetadata* table, + int32 num_fields, + io::CodedOutputStream* output); + +inline void TableSerialize(const MessageLite& msg, + const SerializationTable* table, + io::CodedOutputStream* output) { + const FieldMetadata* field_table = table->field_table; + int num_fields = table->num_fields - 1; + const uint8* base = reinterpret_cast(&msg); + // TODO(gerbens) This skips the first test if we could use the fast + // array serialization path, we should make this + // int cached_size = + // *reinterpret_cast(base + field_table->offset); + // SerializeWithCachedSize(msg, field_table + 1, num_fields, cached_size, ...) + // But we keep conformance with the old way for now. + SerializeInternal(base, field_table + 1, num_fields, output); +} + +uint8* SerializeInternalToArray(const uint8* base, const FieldMetadata* table, + int32 num_fields, bool is_deterministic, + uint8* buffer); + +inline uint8* TableSerializeToArray(const MessageLite& msg, + const SerializationTable* table, + bool is_deterministic, uint8* buffer) { + const uint8* base = reinterpret_cast(&msg); + const FieldMetadata* field_table = table->field_table + 1; + int num_fields = table->num_fields - 1; + return SerializeInternalToArray(base, field_table, num_fields, + is_deterministic, buffer); +} + +template +struct CompareHelper { + bool operator()(const T& a, const T& b) const { return a < b; } +}; + +template <> +struct CompareHelper { + bool operator()(const ArenaStringPtr& a, const ArenaStringPtr& b) const { + return a.Get() < b.Get(); + } +}; + +struct CompareMapKey { + template + bool operator()(const MapEntryHelper& a, + const MapEntryHelper& b) const { + return Compare(a.key_, b.key_); + } + template + bool Compare(const T& a, const T& b) const { + return CompareHelper()(a, b); + } +}; + +template +void MapFieldSerializer(const uint8* base, uint32 offset, uint32 tag, + uint32 has_offset, io::CodedOutputStream* output) { + typedef MapEntryHelper Entry; + typedef typename MapFieldType::MapType::const_iterator Iter; + + const MapFieldType& map_field = + *reinterpret_cast(base + offset); + const SerializationTable* t = + table + + has_offset; // has_offset is overloaded for maps to mean table offset + if (!output->IsSerializationDeterministic()) { + for (Iter it = map_field.GetMap().begin(); it != map_field.GetMap().end(); + ++it) { + Entry map_entry(*it); + output->WriteVarint32(tag); + output->WriteVarint32(map_entry._cached_size_); + SerializeInternal(reinterpret_cast(&map_entry), + t->field_table, t->num_fields, output); + } + } else { + std::vector v; + for (Iter it = map_field.GetMap().begin(); it != map_field.GetMap().end(); + ++it) { + v.push_back(Entry(*it)); + } + std::sort(v.begin(), v.end(), CompareMapKey()); + for (int i = 0; i < v.size(); i++) { + output->WriteVarint32(tag); + output->WriteVarint32(v[i]._cached_size_); + SerializeInternal(reinterpret_cast(&v[i]), t->field_table, + t->num_fields, output); + } + } +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_GENERATED_MESSAGE_TABLE_DRIVEN_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_util.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_util.h new file mode 100644 index 0000000000000000000000000000000000000000..4b68e93b9b782a74eefcb0cddf844ec0f6a4da8b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_util.h @@ -0,0 +1,265 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains miscellaneous helper code used by generated code -- +// including lite types -- but which should not be used directly by users. + +#ifndef GOOGLE_PROTOBUF_GENERATED_MESSAGE_UTIL_H__ +#define GOOGLE_PROTOBUF_GENERATED_MESSAGE_UTIL_H__ + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include // Add direct dep on port for pb.cc +#include +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +class Arena; +class Message; + +namespace io { +class CodedInputStream; +} + +namespace internal { + +template +inline To DownCast(From* f) { + return PROTOBUF_NAMESPACE_ID::internal::down_cast(f); +} +template +inline To DownCast(From& f) { + return PROTOBUF_NAMESPACE_ID::internal::down_cast(f); +} + + +PROTOBUF_EXPORT void InitProtobufDefaults(); + +// This used by proto1 +PROTOBUF_EXPORT inline const std::string& GetEmptyString() { + InitProtobufDefaults(); + return GetEmptyStringAlreadyInited(); +} + + +// True if IsInitialized() is true for all elements of t. Type is expected +// to be a RepeatedPtrField. It's useful to have this +// helper here to keep the protobuf compiler from ever having to emit loops in +// IsInitialized() methods. We want the C++ compiler to inline this or not +// as it sees fit. +template +bool AllAreInitialized(const RepeatedPtrField& t) { + for (int i = t.size(); --i >= 0;) { + if (!t.Get(i).IsInitialized()) return false; + } + return true; +} + +// "Weak" variant of AllAreInitialized, used to implement implicit weak fields. +// This version operates on MessageLite to avoid introducing a dependency on the +// concrete message type. +template +bool AllAreInitializedWeak(const RepeatedPtrField& t) { + for (int i = t.size(); --i >= 0;) { + if (!reinterpret_cast(t) + .Get >(i) + .IsInitialized()) { + return false; + } + } + return true; +} + +inline bool IsPresent(const void* base, uint32 hasbit) { + const uint32* has_bits_array = static_cast(base); + return (has_bits_array[hasbit / 32] & (1u << (hasbit & 31))) != 0; +} + +inline bool IsOneofPresent(const void* base, uint32 offset, uint32 tag) { + const uint32* oneof = + reinterpret_cast(static_cast(base) + offset); + return *oneof == tag >> 3; +} + +typedef void (*SpecialSerializer)(const uint8* base, uint32 offset, uint32 tag, + uint32 has_offset, + io::CodedOutputStream* output); + +PROTOBUF_EXPORT void ExtensionSerializer(const uint8* base, uint32 offset, + uint32 tag, uint32 has_offset, + io::CodedOutputStream* output); +PROTOBUF_EXPORT void UnknownFieldSerializerLite(const uint8* base, + uint32 offset, uint32 tag, + uint32 has_offset, + io::CodedOutputStream* output); + +PROTOBUF_EXPORT MessageLite* DuplicateIfNonNullInternal(MessageLite* message); +PROTOBUF_EXPORT MessageLite* GetOwnedMessageInternal(Arena* message_arena, + MessageLite* submessage, + Arena* submessage_arena); +PROTOBUF_EXPORT void GenericSwap(MessageLite* m1, MessageLite* m2); +// We specialize GenericSwap for non-lite messages to benefit from reflection. +PROTOBUF_EXPORT void GenericSwap(Message* m1, Message* m2); + +template +T* DuplicateIfNonNull(T* message) { + // The casts must be reinterpret_cast<> because T might be a forward-declared + // type that the compiler doesn't know is related to MessageLite. + return reinterpret_cast( + DuplicateIfNonNullInternal(reinterpret_cast(message))); +} + +template +T* GetOwnedMessage(Arena* message_arena, T* submessage, + Arena* submessage_arena) { + // The casts must be reinterpret_cast<> because T might be a forward-declared + // type that the compiler doesn't know is related to MessageLite. + return reinterpret_cast(GetOwnedMessageInternal( + message_arena, reinterpret_cast(submessage), + submessage_arena)); +} + +// Hide atomic from the public header and allow easy change to regular int +// on platforms where the atomic might have a perf impact. +class PROTOBUF_EXPORT CachedSize { + public: + int Get() const { return size_.load(std::memory_order_relaxed); } + void Set(int size) { size_.store(size, std::memory_order_relaxed); } + + private: + std::atomic size_{0}; +}; + +// SCCInfo represents information of a strongly connected component of +// mutual dependent messages. +struct PROTOBUF_EXPORT SCCInfoBase { + // We use 0 for the Initialized state, because test eax,eax, jnz is smaller + // and is subject to macro fusion. + enum { + kInitialized = 0, // final state + kRunning = 1, + kUninitialized = -1, // initial state + }; +#if defined(_MSC_VER) && !defined(__clang__) + // MSVC doesn't make std::atomic constant initialized. This union trick + // makes it so. + union { + int visit_status_to_make_linker_init; + std::atomic visit_status; + }; +#else + std::atomic visit_status; +#endif + int num_deps; + int num_implicit_weak_deps; + void (*init_func)(); + // This is followed by an array of num_deps + // const SCCInfoBase* deps[]; +}; + +// Zero-length arrays are a language extension available in GCC and Clang but +// not MSVC. +#ifdef __GNUC__ +#define PROTOBUF_ARRAY_SIZE(n) (n) +#else +#define PROTOBUF_ARRAY_SIZE(n) ((n) ? (n) : 1) +#endif + +template +struct SCCInfo { + SCCInfoBase base; + // Semantically this is const SCCInfo* which is is a templated type. + // The obvious inheriting from SCCInfoBase mucks with struct initialization. + // Attempts showed the compiler was generating dynamic initialization code. + // This deps array consists of base.num_deps pointers to SCCInfoBase followed + // by base.num_implicit_weak_deps pointers to SCCInfoBase*. We need the extra + // pointer indirection for implicit weak fields. We cannot use a union type + // here, since that would prevent the array from being linker-initialized. + void* deps[PROTOBUF_ARRAY_SIZE(N)]; +}; + +#undef PROTOBUF_ARRAY_SIZE + +PROTOBUF_EXPORT void InitSCCImpl(SCCInfoBase* scc); + +inline void InitSCC(SCCInfoBase* scc) { + auto status = scc->visit_status.load(std::memory_order_acquire); + if (PROTOBUF_PREDICT_FALSE(status != SCCInfoBase::kInitialized)) + InitSCCImpl(scc); +} + +PROTOBUF_EXPORT void DestroyMessage(const void* message); +PROTOBUF_EXPORT void DestroyString(const void* s); +// Destroy (not delete) the message +inline void OnShutdownDestroyMessage(const void* ptr) { + OnShutdownRun(DestroyMessage, ptr); +} +// Destroy the string (call std::string destructor) +inline void OnShutdownDestroyString(const std::string* ptr) { + OnShutdownRun(DestroyString, ptr); +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_GENERATED_MESSAGE_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/has_bits.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/has_bits.h new file mode 100644 index 0000000000000000000000000000000000000000..f54c11b9035ed2e68493d957ffb3af016fb90f95 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/has_bits.h @@ -0,0 +1,121 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_HAS_BITS_H__ +#define GOOGLE_PROTOBUF_HAS_BITS_H__ + +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +namespace internal { + +template +class HasBits { + public: + HasBits() PROTOBUF_ALWAYS_INLINE { Clear(); } + + void Clear() PROTOBUF_ALWAYS_INLINE { + memset(has_bits_, 0, sizeof(has_bits_)); + } + + uint32& operator[](int index) PROTOBUF_ALWAYS_INLINE { + return has_bits_[index]; + } + + const uint32& operator[](int index) const PROTOBUF_ALWAYS_INLINE { + return has_bits_[index]; + } + + bool operator==(const HasBits& rhs) const { + return memcmp(has_bits_, rhs.has_bits_, sizeof(has_bits_)) == 0; + } + + bool operator!=(const HasBits& rhs) const { + return !(*this == rhs); + } + + void Or(const HasBits& rhs) { + for (size_t i = 0; i < doublewords; i++) has_bits_[i] |= rhs[i]; + } + + bool empty() const; + + private: + uint32 has_bits_[doublewords]; +}; + +template <> +inline bool HasBits<1>::empty() const { + return !has_bits_[0]; +} + +template <> +inline bool HasBits<2>::empty() const { + return !(has_bits_[0] | has_bits_[1]); +} + +template <> +inline bool HasBits<3>::empty() const { + return !(has_bits_[0] | has_bits_[1] | has_bits_[2]); +} + +template <> +inline bool HasBits<4>::empty() const { + return !(has_bits_[0] | has_bits_[1] | has_bits_[2] | has_bits_[3]); +} + +template +inline bool HasBits::empty() const { + for (size_t i = 0; i < doublewords; ++i) { + if (has_bits_[i]) return false; + } + return true; +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_HAS_BITS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/implicit_weak_message.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/implicit_weak_message.h new file mode 100644 index 0000000000000000000000000000000000000000..d373a520a14bf5ab8f0bd6a354e84d0a0f1580ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/implicit_weak_message.h @@ -0,0 +1,195 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_IMPLICIT_WEAK_MESSAGE_H__ +#define GOOGLE_PROTOBUF_IMPLICIT_WEAK_MESSAGE_H__ + +#include + +#include +#include +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +// This file is logically internal-only and should only be used by protobuf +// generated code. + +namespace google { +namespace protobuf { +namespace internal { + +// An implementation of MessageLite that treats all data as unknown. This type +// acts as a placeholder for an implicit weak field in the case where the true +// message type does not get linked into the binary. +class PROTOBUF_EXPORT ImplicitWeakMessage : public MessageLite { + public: + ImplicitWeakMessage() {} + explicit ImplicitWeakMessage(Arena* arena) : MessageLite(arena) {} + + static const ImplicitWeakMessage* default_instance(); + + std::string GetTypeName() const override { return ""; } + + MessageLite* New() const override { return new ImplicitWeakMessage; } + MessageLite* New(Arena* arena) const override { + return Arena::CreateMessage(arena); + } + + void Clear() override { data_.clear(); } + + bool IsInitialized() const override { return true; } + + void CheckTypeAndMergeFrom(const MessageLite& other) override { + data_.append(static_cast(other).data_); + } + + const char* _InternalParse(const char* ptr, ParseContext* ctx) final; + + size_t ByteSizeLong() const override { return data_.size(); } + + uint8* _InternalSerialize(uint8* target, + io::EpsCopyOutputStream* stream) const final { + return stream->WriteRaw(data_.data(), static_cast(data_.size()), + target); + } + + int GetCachedSize() const override { return static_cast(data_.size()); } + + typedef void InternalArenaConstructable_; + + private: + std::string data_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ImplicitWeakMessage); +}; + +// A type handler for use with implicit weak repeated message fields. +template +class ImplicitWeakTypeHandler { + public: + typedef MessageLite Type; + static constexpr bool Moveable = false; + + static inline MessageLite* NewFromPrototype(const MessageLite* prototype, + Arena* arena = NULL) { + return prototype->New(arena); + } + + static inline void Delete(MessageLite* value, Arena* arena) { + if (arena == NULL) { + delete value; + } + } + static inline Arena* GetArena(MessageLite* value) { + return value->GetArena(); + } + static inline void* GetMaybeArenaPointer(MessageLite* value) { + return value->GetArena(); + } + static inline void Clear(MessageLite* value) { value->Clear(); } + static void Merge(const MessageLite& from, MessageLite* to) { + to->CheckTypeAndMergeFrom(from); + } +}; + +} // namespace internal + +template +struct WeakRepeatedPtrField { + using TypeHandler = internal::ImplicitWeakTypeHandler; + WeakRepeatedPtrField() : weak() {} + explicit WeakRepeatedPtrField(Arena* arena) : weak(arena) {} + ~WeakRepeatedPtrField() { weak.template Destroy(); } + + typedef internal::RepeatedPtrIterator iterator; + typedef internal::RepeatedPtrIterator const_iterator; + typedef internal::RepeatedPtrOverPtrsIterator + pointer_iterator; + typedef internal::RepeatedPtrOverPtrsIterator + const_pointer_iterator; + + iterator begin() { return iterator(base().raw_data()); } + const_iterator begin() const { return iterator(base().raw_data()); } + const_iterator cbegin() const { return begin(); } + iterator end() { return begin() + base().size(); } + const_iterator end() const { return begin() + base().size(); } + const_iterator cend() const { return end(); } + pointer_iterator pointer_begin() { + return pointer_iterator(base().raw_mutable_data()); + } + const_pointer_iterator pointer_begin() const { + return const_pointer_iterator(base().raw_mutable_data()); + } + pointer_iterator pointer_end() { + return pointer_iterator(base().raw_mutable_data() + base().size()); + } + const_pointer_iterator pointer_end() const { + return const_pointer_iterator(base().raw_mutable_data() + base().size()); + } + + MessageLite* AddWeak(const MessageLite* prototype) { + return base().AddWeak(prototype); + } + T* Add() { return weak.Add(); } + void Clear() { base().template Clear(); } + void MergeFrom(const WeakRepeatedPtrField& other) { + base().template MergeFrom(other.base()); + } + void InternalSwap(WeakRepeatedPtrField* other) { + base().InternalSwap(&other->base()); + } + + const internal::RepeatedPtrFieldBase& base() const { return weak; } + internal::RepeatedPtrFieldBase& base() { return weak; } + // Union disables running the destructor. Which would create a strong link. + // Instead we explicitly destroy the underlying base through the virtual + // destructor. + union { + RepeatedPtrField weak; + }; +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IMPLICIT_WEAK_MESSAGE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/inlined_string_field.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/inlined_string_field.h new file mode 100644 index 0000000000000000000000000000000000000000..14337107a154f25afb83039a633a3cfe3c8367e1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/inlined_string_field.h @@ -0,0 +1,265 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_INLINED_STRING_FIELD_H__ +#define GOOGLE_PROTOBUF_INLINED_STRING_FIELD_H__ + +#include +#include + +#include +#include + +// Must be included last. +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +class Arena; + +namespace internal { + +// InlinedStringField wraps a std::string instance and exposes an API similar to +// ArenaStringPtr's wrapping of a std::string* instance. As std::string is +// never allocated on the Arena, we expose only the *NoArena methods of +// ArenaStringPtr. +// +// default_value parameters are taken for consistency with ArenaStringPtr, but +// are not used for most methods. With inlining, these should be removed from +// the generated binary. +class PROTOBUF_EXPORT InlinedStringField { + public: + InlinedStringField() PROTOBUF_ALWAYS_INLINE; + explicit InlinedStringField(const std::string& default_value); + + void AssignWithDefault(const std::string* default_value, + const InlinedStringField& from) PROTOBUF_ALWAYS_INLINE; + + void ClearToEmpty(const std::string* default_value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + ClearToEmptyNoArena(default_value); + } + void ClearNonDefaultToEmpty() PROTOBUF_ALWAYS_INLINE { + ClearNonDefaultToEmptyNoArena(); + } + void ClearToEmptyNoArena(const std::string* /*default_value*/) + PROTOBUF_ALWAYS_INLINE { + ClearNonDefaultToEmptyNoArena(); + } + void ClearNonDefaultToEmptyNoArena() PROTOBUF_ALWAYS_INLINE; + + void ClearToDefault(const std::string* default_value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + ClearToDefaultNoArena(default_value); + } + void ClearToDefaultNoArena(const std::string* default_value) + PROTOBUF_ALWAYS_INLINE; + + void Destroy(const std::string* default_value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + DestroyNoArena(default_value); + } + void DestroyNoArena(const std::string* default_value) PROTOBUF_ALWAYS_INLINE; + + const std::string& Get() const PROTOBUF_ALWAYS_INLINE { return GetNoArena(); } + const std::string& GetNoArena() const PROTOBUF_ALWAYS_INLINE; + + std::string* Mutable(const std::string* default_value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + return MutableNoArena(default_value); + } + std::string* MutableNoArena(const std::string* default_value) + PROTOBUF_ALWAYS_INLINE; + + std::string* Release(const std::string* default_value, Arena* /*arena*/) { + return ReleaseNoArena(default_value); + } + std::string* ReleaseNonDefault(const std::string* default_value, + Arena* /*arena*/) { + return ReleaseNonDefaultNoArena(default_value); + } + std::string* ReleaseNoArena(const std::string* default_value) { + return ReleaseNonDefaultNoArena(default_value); + } + std::string* ReleaseNonDefaultNoArena(const std::string* default_value); + + void Set(const std::string* default_value, StringPiece value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + SetNoArena(default_value, value); + } + void SetLite(const std::string* default_value, StringPiece value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + SetNoArena(default_value, value); + } + void SetNoArena(const std::string* default_value, + StringPiece value) PROTOBUF_ALWAYS_INLINE; + + void Set(const std::string* default_value, const std::string& value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + SetNoArena(default_value, value); + } + void SetLite(const std::string* default_value, const std::string& value, + Arena* /*arena*/) PROTOBUF_ALWAYS_INLINE { + SetNoArena(default_value, value); + } + void SetNoArena(const std::string* default_value, + const std::string& value) PROTOBUF_ALWAYS_INLINE; + + void SetNoArena(const std::string* default_value, + std::string&& value) PROTOBUF_ALWAYS_INLINE; + void SetAllocated(const std::string* default_value, std::string* value, + Arena* /*arena*/) { + SetAllocatedNoArena(default_value, value); + } + void SetAllocatedNoArena(const std::string* default_value, + std::string* value); + void Swap(InlinedStringField* from) PROTOBUF_ALWAYS_INLINE; + std::string* UnsafeMutablePointer(); + void UnsafeSetDefault(const std::string* default_value); + std::string* UnsafeArenaRelease(const std::string* default_value, + Arena* arena); + void UnsafeArenaSetAllocated(const std::string* default_value, + std::string* value, Arena* arena); + + bool IsDefault(const std::string* /*default_value*/) { return false; } + + private: + std::string value_; +}; + +inline InlinedStringField::InlinedStringField() {} + +inline InlinedStringField::InlinedStringField(const std::string& default_value) + : value_(default_value) {} + +inline void InlinedStringField::AssignWithDefault( + const std::string* /*default_value*/, const InlinedStringField& from) { + value_ = from.value_; +} + +inline const std::string& InlinedStringField::GetNoArena() const { + return value_; +} + +inline std::string* InlinedStringField::MutableNoArena(const std::string*) { + return &value_; +} + +inline void InlinedStringField::SetAllocatedNoArena( + const std::string* default_value, std::string* value) { + if (value == NULL) { + value_.assign(*default_value); + } else { + value_.assign(std::move(*value)); + delete value; + } +} + +inline void InlinedStringField::DestroyNoArena(const std::string*) { + // This is invoked from the generated message's ArenaDtor, which is used to + // clean up objects not allocated on the Arena. + this->~InlinedStringField(); +} + +inline void InlinedStringField::ClearNonDefaultToEmptyNoArena() { + value_.clear(); +} + +inline void InlinedStringField::ClearToDefaultNoArena( + const std::string* default_value) { + value_.assign(*default_value); +} + +inline std::string* InlinedStringField::ReleaseNonDefaultNoArena( + const std::string* default_value) { + std::string* released = new std::string(*default_value); + value_.swap(*released); + return released; +} + +inline void InlinedStringField::SetNoArena(const std::string* /*default_value*/, + StringPiece value) { + value_.assign(value.data(), value.length()); +} + +inline void InlinedStringField::SetNoArena(const std::string* /*default_value*/, + const std::string& value) { + value_.assign(value); +} + +inline void InlinedStringField::SetNoArena(const std::string* /*default_value*/, + std::string&& value) { + value_.assign(std::move(value)); +} + +inline void InlinedStringField::Swap(InlinedStringField* from) { + value_.swap(from->value_); +} + +inline std::string* InlinedStringField::UnsafeMutablePointer() { + return &value_; +} + +inline void InlinedStringField::UnsafeSetDefault( + const std::string* default_value) { + value_.assign(*default_value); +} + +inline std::string* InlinedStringField::UnsafeArenaRelease( + const std::string* default_value, Arena* /*arena*/) { + return ReleaseNoArena(default_value); +} + +inline void InlinedStringField::UnsafeArenaSetAllocated( + const std::string* default_value, std::string* value, Arena* /*arena*/) { + if (value == NULL) { + value_.assign(*default_value); + } else { + value_.assign(*value); + } +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_INLINED_STRING_FIELD_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/coded_stream.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/coded_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..061d60cd71990af74cabfd23441975367ff636ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/coded_stream.h @@ -0,0 +1,1719 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains the CodedInputStream and CodedOutputStream classes, +// which wrap a ZeroCopyInputStream or ZeroCopyOutputStream, respectively, +// and allow you to read or write individual pieces of data in various +// formats. In particular, these implement the varint encoding for +// integers, a simple variable-length encoding in which smaller numbers +// take fewer bytes. +// +// Typically these classes will only be used internally by the protocol +// buffer library in order to encode and decode protocol buffers. Clients +// of the library only need to know about this class if they wish to write +// custom message parsing or serialization procedures. +// +// CodedOutputStream example: +// // Write some data to "myfile". First we write a 4-byte "magic number" +// // to identify the file type, then write a length-delimited string. The +// // string is composed of a varint giving the length followed by the raw +// // bytes. +// int fd = open("myfile", O_CREAT | O_WRONLY); +// ZeroCopyOutputStream* raw_output = new FileOutputStream(fd); +// CodedOutputStream* coded_output = new CodedOutputStream(raw_output); +// +// int magic_number = 1234; +// char text[] = "Hello world!"; +// coded_output->WriteLittleEndian32(magic_number); +// coded_output->WriteVarint32(strlen(text)); +// coded_output->WriteRaw(text, strlen(text)); +// +// delete coded_output; +// delete raw_output; +// close(fd); +// +// CodedInputStream example: +// // Read a file created by the above code. +// int fd = open("myfile", O_RDONLY); +// ZeroCopyInputStream* raw_input = new FileInputStream(fd); +// CodedInputStream* coded_input = new CodedInputStream(raw_input); +// +// coded_input->ReadLittleEndian32(&magic_number); +// if (magic_number != 1234) { +// cerr << "File not in expected format." << endl; +// return; +// } +// +// uint32 size; +// coded_input->ReadVarint32(&size); +// +// char* text = new char[size + 1]; +// coded_input->ReadRaw(buffer, size); +// text[size] = '\0'; +// +// delete coded_input; +// delete raw_input; +// close(fd); +// +// cout << "Text is: " << text << endl; +// delete [] text; +// +// For those who are interested, varint encoding is defined as follows: +// +// The encoding operates on unsigned integers of up to 64 bits in length. +// Each byte of the encoded value has the format: +// * bits 0-6: Seven bits of the number being encoded. +// * bit 7: Zero if this is the last byte in the encoding (in which +// case all remaining bits of the number are zero) or 1 if +// more bytes follow. +// The first byte contains the least-significant 7 bits of the number, the +// second byte (if present) contains the next-least-significant 7 bits, +// and so on. So, the binary number 1011000101011 would be encoded in two +// bytes as "10101011 00101100". +// +// In theory, varint could be used to encode integers of any length. +// However, for practicality we set a limit at 64 bits. The maximum encoded +// length of a number is thus 10 bytes. + +#ifndef GOOGLE_PROTOBUF_IO_CODED_STREAM_H__ +#define GOOGLE_PROTOBUF_IO_CODED_STREAM_H__ + + +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +// Assuming windows is always little-endian. +#if !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) +#define PROTOBUF_LITTLE_ENDIAN 1 +#endif +#if _MSC_VER >= 1300 && !defined(__INTEL_COMPILER) +// If MSVC has "/RTCc" set, it will complain about truncating casts at +// runtime. This file contains some intentional truncating casts. +#pragma runtime_checks("c", off) +#endif +#else +#include // __BYTE_ORDER +#if ((defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)) || \ + (defined(__BYTE_ORDER) && __BYTE_ORDER == __LITTLE_ENDIAN)) && \ + !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) +#define PROTOBUF_LITTLE_ENDIAN 1 +#endif +#endif +#include +#include +#include +#include +#include + + +#include + +namespace google { +namespace protobuf { + +class DescriptorPool; +class MessageFactory; +class ZeroCopyCodedInputStream; + +namespace internal { +void MapTestForceDeterministic(); +class EpsCopyByteStream; +} // namespace internal + +namespace io { + +// Defined in this file. +class CodedInputStream; +class CodedOutputStream; + +// Defined in other files. +class ZeroCopyInputStream; // zero_copy_stream.h +class ZeroCopyOutputStream; // zero_copy_stream.h + +// Class which reads and decodes binary data which is composed of varint- +// encoded integers and fixed-width pieces. Wraps a ZeroCopyInputStream. +// Most users will not need to deal with CodedInputStream. +// +// Most methods of CodedInputStream that return a bool return false if an +// underlying I/O error occurs or if the data is malformed. Once such a +// failure occurs, the CodedInputStream is broken and is no longer useful. +// After a failure, callers also should assume writes to "out" args may have +// occurred, though nothing useful can be determined from those writes. +class PROTOBUF_EXPORT CodedInputStream { + public: + // Create a CodedInputStream that reads from the given ZeroCopyInputStream. + explicit CodedInputStream(ZeroCopyInputStream* input); + + // Create a CodedInputStream that reads from the given flat array. This is + // faster than using an ArrayInputStream. PushLimit(size) is implied by + // this constructor. + explicit CodedInputStream(const uint8* buffer, int size); + + // Destroy the CodedInputStream and position the underlying + // ZeroCopyInputStream at the first unread byte. If an error occurred while + // reading (causing a method to return false), then the exact position of + // the input stream may be anywhere between the last value that was read + // successfully and the stream's byte limit. + ~CodedInputStream(); + + // Return true if this CodedInputStream reads from a flat array instead of + // a ZeroCopyInputStream. + inline bool IsFlat() const; + + // Skips a number of bytes. Returns false if an underlying read error + // occurs. + inline bool Skip(int count); + + // Sets *data to point directly at the unread part of the CodedInputStream's + // underlying buffer, and *size to the size of that buffer, but does not + // advance the stream's current position. This will always either produce + // a non-empty buffer or return false. If the caller consumes any of + // this data, it should then call Skip() to skip over the consumed bytes. + // This may be useful for implementing external fast parsing routines for + // types of data not covered by the CodedInputStream interface. + bool GetDirectBufferPointer(const void** data, int* size); + + // Like GetDirectBufferPointer, but this method is inlined, and does not + // attempt to Refresh() if the buffer is currently empty. + PROTOBUF_ALWAYS_INLINE + void GetDirectBufferPointerInline(const void** data, int* size); + + // Read raw bytes, copying them into the given buffer. + bool ReadRaw(void* buffer, int size); + + // Like ReadRaw, but reads into a string. + bool ReadString(std::string* buffer, int size); + + + // Read a 32-bit little-endian integer. + bool ReadLittleEndian32(uint32* value); + // Read a 64-bit little-endian integer. + bool ReadLittleEndian64(uint64* value); + + // These methods read from an externally provided buffer. The caller is + // responsible for ensuring that the buffer has sufficient space. + // Read a 32-bit little-endian integer. + static const uint8* ReadLittleEndian32FromArray(const uint8* buffer, + uint32* value); + // Read a 64-bit little-endian integer. + static const uint8* ReadLittleEndian64FromArray(const uint8* buffer, + uint64* value); + + // Read an unsigned integer with Varint encoding, truncating to 32 bits. + // Reading a 32-bit value is equivalent to reading a 64-bit one and casting + // it to uint32, but may be more efficient. + bool ReadVarint32(uint32* value); + // Read an unsigned integer with Varint encoding. + bool ReadVarint64(uint64* value); + + // Reads a varint off the wire into an "int". This should be used for reading + // sizes off the wire (sizes of strings, submessages, bytes fields, etc). + // + // The value from the wire is interpreted as unsigned. If its value exceeds + // the representable value of an integer on this platform, instead of + // truncating we return false. Truncating (as performed by ReadVarint32() + // above) is an acceptable approach for fields representing an integer, but + // when we are parsing a size from the wire, truncating the value would result + // in us misparsing the payload. + bool ReadVarintSizeAsInt(int* value); + + // Read a tag. This calls ReadVarint32() and returns the result, or returns + // zero (which is not a valid tag) if ReadVarint32() fails. Also, ReadTag + // (but not ReadTagNoLastTag) updates the last tag value, which can be checked + // with LastTagWas(). + // + // Always inline because this is only called in one place per parse loop + // but it is called for every iteration of said loop, so it should be fast. + // GCC doesn't want to inline this by default. + PROTOBUF_ALWAYS_INLINE uint32 ReadTag() { + return last_tag_ = ReadTagNoLastTag(); + } + + PROTOBUF_ALWAYS_INLINE uint32 ReadTagNoLastTag(); + + // This usually a faster alternative to ReadTag() when cutoff is a manifest + // constant. It does particularly well for cutoff >= 127. The first part + // of the return value is the tag that was read, though it can also be 0 in + // the cases where ReadTag() would return 0. If the second part is true + // then the tag is known to be in [0, cutoff]. If not, the tag either is + // above cutoff or is 0. (There's intentional wiggle room when tag is 0, + // because that can arise in several ways, and for best performance we want + // to avoid an extra "is tag == 0?" check here.) + PROTOBUF_ALWAYS_INLINE + std::pair ReadTagWithCutoff(uint32 cutoff) { + std::pair result = ReadTagWithCutoffNoLastTag(cutoff); + last_tag_ = result.first; + return result; + } + + PROTOBUF_ALWAYS_INLINE + std::pair ReadTagWithCutoffNoLastTag(uint32 cutoff); + + // Usually returns true if calling ReadVarint32() now would produce the given + // value. Will always return false if ReadVarint32() would not return the + // given value. If ExpectTag() returns true, it also advances past + // the varint. For best performance, use a compile-time constant as the + // parameter. + // Always inline because this collapses to a small number of instructions + // when given a constant parameter, but GCC doesn't want to inline by default. + PROTOBUF_ALWAYS_INLINE bool ExpectTag(uint32 expected); + + // Like above, except this reads from the specified buffer. The caller is + // responsible for ensuring that the buffer is large enough to read a varint + // of the expected size. For best performance, use a compile-time constant as + // the expected tag parameter. + // + // Returns a pointer beyond the expected tag if it was found, or NULL if it + // was not. + PROTOBUF_ALWAYS_INLINE + static const uint8* ExpectTagFromArray(const uint8* buffer, uint32 expected); + + // Usually returns true if no more bytes can be read. Always returns false + // if more bytes can be read. If ExpectAtEnd() returns true, a subsequent + // call to LastTagWas() will act as if ReadTag() had been called and returned + // zero, and ConsumedEntireMessage() will return true. + bool ExpectAtEnd(); + + // If the last call to ReadTag() or ReadTagWithCutoff() returned the given + // value, returns true. Otherwise, returns false. + // ReadTagNoLastTag/ReadTagWithCutoffNoLastTag do not preserve the last + // returned value. + // + // This is needed because parsers for some types of embedded messages + // (with field type TYPE_GROUP) don't actually know that they've reached the + // end of a message until they see an ENDGROUP tag, which was actually part + // of the enclosing message. The enclosing message would like to check that + // tag to make sure it had the right number, so it calls LastTagWas() on + // return from the embedded parser to check. + bool LastTagWas(uint32 expected); + void SetLastTag(uint32 tag) { last_tag_ = tag; } + + // When parsing message (but NOT a group), this method must be called + // immediately after MergeFromCodedStream() returns (if it returns true) + // to further verify that the message ended in a legitimate way. For + // example, this verifies that parsing did not end on an end-group tag. + // It also checks for some cases where, due to optimizations, + // MergeFromCodedStream() can incorrectly return true. + bool ConsumedEntireMessage(); + void SetConsumed() { legitimate_message_end_ = true; } + + // Limits ---------------------------------------------------------- + // Limits are used when parsing length-delimited embedded messages. + // After the message's length is read, PushLimit() is used to prevent + // the CodedInputStream from reading beyond that length. Once the + // embedded message has been parsed, PopLimit() is called to undo the + // limit. + + // Opaque type used with PushLimit() and PopLimit(). Do not modify + // values of this type yourself. The only reason that this isn't a + // struct with private internals is for efficiency. + typedef int Limit; + + // Places a limit on the number of bytes that the stream may read, + // starting from the current position. Once the stream hits this limit, + // it will act like the end of the input has been reached until PopLimit() + // is called. + // + // As the names imply, the stream conceptually has a stack of limits. The + // shortest limit on the stack is always enforced, even if it is not the + // top limit. + // + // The value returned by PushLimit() is opaque to the caller, and must + // be passed unchanged to the corresponding call to PopLimit(). + Limit PushLimit(int byte_limit); + + // Pops the last limit pushed by PushLimit(). The input must be the value + // returned by that call to PushLimit(). + void PopLimit(Limit limit); + + // Returns the number of bytes left until the nearest limit on the + // stack is hit, or -1 if no limits are in place. + int BytesUntilLimit() const; + + // Returns current position relative to the beginning of the input stream. + int CurrentPosition() const; + + // Total Bytes Limit ----------------------------------------------- + // To prevent malicious users from sending excessively large messages + // and causing memory exhaustion, CodedInputStream imposes a hard limit on + // the total number of bytes it will read. + + // Sets the maximum number of bytes that this CodedInputStream will read + // before refusing to continue. To prevent servers from allocating enormous + // amounts of memory to hold parsed messages, the maximum message length + // should be limited to the shortest length that will not harm usability. + // The default limit is INT_MAX (~2GB) and apps should set shorter limits + // if possible. An error will always be printed to stderr if the limit is + // reached. + // + // Note: setting a limit less than the current read position is interpreted + // as a limit on the current position. + // + // This is unrelated to PushLimit()/PopLimit(). + void SetTotalBytesLimit(int total_bytes_limit); + + PROTOBUF_DEPRECATED_MSG( + "Please use the single parameter version of SetTotalBytesLimit(). The " + "second parameter is ignored.") + void SetTotalBytesLimit(int total_bytes_limit, int) { + SetTotalBytesLimit(total_bytes_limit); + } + + // The Total Bytes Limit minus the Current Position, or -1 if the total bytes + // limit is INT_MAX. + int BytesUntilTotalBytesLimit() const; + + // Recursion Limit ------------------------------------------------- + // To prevent corrupt or malicious messages from causing stack overflows, + // we must keep track of the depth of recursion when parsing embedded + // messages and groups. CodedInputStream keeps track of this because it + // is the only object that is passed down the stack during parsing. + + // Sets the maximum recursion depth. The default is 100. + void SetRecursionLimit(int limit); + int RecursionBudget() { return recursion_budget_; } + + static int GetDefaultRecursionLimit() { return default_recursion_limit_; } + + // Increments the current recursion depth. Returns true if the depth is + // under the limit, false if it has gone over. + bool IncrementRecursionDepth(); + + // Decrements the recursion depth if possible. + void DecrementRecursionDepth(); + + // Decrements the recursion depth blindly. This is faster than + // DecrementRecursionDepth(). It should be used only if all previous + // increments to recursion depth were successful. + void UnsafeDecrementRecursionDepth(); + + // Shorthand for make_pair(PushLimit(byte_limit), --recursion_budget_). + // Using this can reduce code size and complexity in some cases. The caller + // is expected to check that the second part of the result is non-negative (to + // bail out if the depth of recursion is too high) and, if all is well, to + // later pass the first part of the result to PopLimit() or similar. + std::pair IncrementRecursionDepthAndPushLimit( + int byte_limit); + + // Shorthand for PushLimit(ReadVarint32(&length) ? length : 0). + Limit ReadLengthAndPushLimit(); + + // Helper that is equivalent to: { + // bool result = ConsumedEntireMessage(); + // PopLimit(limit); + // UnsafeDecrementRecursionDepth(); + // return result; } + // Using this can reduce code size and complexity in some cases. + // Do not use unless the current recursion depth is greater than zero. + bool DecrementRecursionDepthAndPopLimit(Limit limit); + + // Helper that is equivalent to: { + // bool result = ConsumedEntireMessage(); + // PopLimit(limit); + // return result; } + // Using this can reduce code size and complexity in some cases. + bool CheckEntireMessageConsumedAndPopLimit(Limit limit); + + // Extension Registry ---------------------------------------------- + // ADVANCED USAGE: 99.9% of people can ignore this section. + // + // By default, when parsing extensions, the parser looks for extension + // definitions in the pool which owns the outer message's Descriptor. + // However, you may call SetExtensionRegistry() to provide an alternative + // pool instead. This makes it possible, for example, to parse a message + // using a generated class, but represent some extensions using + // DynamicMessage. + + // Set the pool used to look up extensions. Most users do not need to call + // this as the correct pool will be chosen automatically. + // + // WARNING: It is very easy to misuse this. Carefully read the requirements + // below. Do not use this unless you are sure you need it. Almost no one + // does. + // + // Let's say you are parsing a message into message object m, and you want + // to take advantage of SetExtensionRegistry(). You must follow these + // requirements: + // + // The given DescriptorPool must contain m->GetDescriptor(). It is not + // sufficient for it to simply contain a descriptor that has the same name + // and content -- it must be the *exact object*. In other words: + // assert(pool->FindMessageTypeByName(m->GetDescriptor()->full_name()) == + // m->GetDescriptor()); + // There are two ways to satisfy this requirement: + // 1) Use m->GetDescriptor()->pool() as the pool. This is generally useless + // because this is the pool that would be used anyway if you didn't call + // SetExtensionRegistry() at all. + // 2) Use a DescriptorPool which has m->GetDescriptor()->pool() as an + // "underlay". Read the documentation for DescriptorPool for more + // information about underlays. + // + // You must also provide a MessageFactory. This factory will be used to + // construct Message objects representing extensions. The factory's + // GetPrototype() MUST return non-NULL for any Descriptor which can be found + // through the provided pool. + // + // If the provided factory might return instances of protocol-compiler- + // generated (i.e. compiled-in) types, or if the outer message object m is + // a generated type, then the given factory MUST have this property: If + // GetPrototype() is given a Descriptor which resides in + // DescriptorPool::generated_pool(), the factory MUST return the same + // prototype which MessageFactory::generated_factory() would return. That + // is, given a descriptor for a generated type, the factory must return an + // instance of the generated class (NOT DynamicMessage). However, when + // given a descriptor for a type that is NOT in generated_pool, the factory + // is free to return any implementation. + // + // The reason for this requirement is that generated sub-objects may be + // accessed via the standard (non-reflection) extension accessor methods, + // and these methods will down-cast the object to the generated class type. + // If the object is not actually of that type, the results would be undefined. + // On the other hand, if an extension is not compiled in, then there is no + // way the code could end up accessing it via the standard accessors -- the + // only way to access the extension is via reflection. When using reflection, + // DynamicMessage and generated messages are indistinguishable, so it's fine + // if these objects are represented using DynamicMessage. + // + // Using DynamicMessageFactory on which you have called + // SetDelegateToGeneratedFactory(true) should be sufficient to satisfy the + // above requirement. + // + // If either pool or factory is NULL, both must be NULL. + // + // Note that this feature is ignored when parsing "lite" messages as they do + // not have descriptors. + void SetExtensionRegistry(const DescriptorPool* pool, + MessageFactory* factory); + + // Get the DescriptorPool set via SetExtensionRegistry(), or NULL if no pool + // has been provided. + const DescriptorPool* GetExtensionPool(); + + // Get the MessageFactory set via SetExtensionRegistry(), or NULL if no + // factory has been provided. + MessageFactory* GetExtensionFactory(); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CodedInputStream); + + const uint8* buffer_; + const uint8* buffer_end_; // pointer to the end of the buffer. + ZeroCopyInputStream* input_; + int total_bytes_read_; // total bytes read from input_, including + // the current buffer + + // If total_bytes_read_ surpasses INT_MAX, we record the extra bytes here + // so that we can BackUp() on destruction. + int overflow_bytes_; + + // LastTagWas() stuff. + uint32 last_tag_; // result of last ReadTag() or ReadTagWithCutoff(). + + // This is set true by ReadTag{Fallback/Slow}() if it is called when exactly + // at EOF, or by ExpectAtEnd() when it returns true. This happens when we + // reach the end of a message and attempt to read another tag. + bool legitimate_message_end_; + + // See EnableAliasing(). + bool aliasing_enabled_; + + // Limits + Limit current_limit_; // if position = -1, no limit is applied + + // For simplicity, if the current buffer crosses a limit (either a normal + // limit created by PushLimit() or the total bytes limit), buffer_size_ + // only tracks the number of bytes before that limit. This field + // contains the number of bytes after it. Note that this implies that if + // buffer_size_ == 0 and buffer_size_after_limit_ > 0, we know we've + // hit a limit. However, if both are zero, it doesn't necessarily mean + // we aren't at a limit -- the buffer may have ended exactly at the limit. + int buffer_size_after_limit_; + + // Maximum number of bytes to read, period. This is unrelated to + // current_limit_. Set using SetTotalBytesLimit(). + int total_bytes_limit_; + + // Current recursion budget, controlled by IncrementRecursionDepth() and + // similar. Starts at recursion_limit_ and goes down: if this reaches + // -1 we are over budget. + int recursion_budget_; + // Recursion depth limit, set by SetRecursionLimit(). + int recursion_limit_; + + // See SetExtensionRegistry(). + const DescriptorPool* extension_pool_; + MessageFactory* extension_factory_; + + // Private member functions. + + // Fallback when Skip() goes past the end of the current buffer. + bool SkipFallback(int count, int original_buffer_size); + + // Advance the buffer by a given number of bytes. + void Advance(int amount); + + // Back up input_ to the current buffer position. + void BackUpInputToCurrentPosition(); + + // Recomputes the value of buffer_size_after_limit_. Must be called after + // current_limit_ or total_bytes_limit_ changes. + void RecomputeBufferLimits(); + + // Writes an error message saying that we hit total_bytes_limit_. + void PrintTotalBytesLimitError(); + + // Called when the buffer runs out to request more data. Implies an + // Advance(BufferSize()). + bool Refresh(); + + // When parsing varints, we optimize for the common case of small values, and + // then optimize for the case when the varint fits within the current buffer + // piece. The Fallback method is used when we can't use the one-byte + // optimization. The Slow method is yet another fallback when the buffer is + // not large enough. Making the slow path out-of-line speeds up the common + // case by 10-15%. The slow path is fairly uncommon: it only triggers when a + // message crosses multiple buffers. Note: ReadVarint32Fallback() and + // ReadVarint64Fallback() are called frequently and generally not inlined, so + // they have been optimized to avoid "out" parameters. The former returns -1 + // if it fails and the uint32 it read otherwise. The latter has a bool + // indicating success or failure as part of its return type. + int64 ReadVarint32Fallback(uint32 first_byte_or_zero); + int ReadVarintSizeAsIntFallback(); + std::pair ReadVarint64Fallback(); + bool ReadVarint32Slow(uint32* value); + bool ReadVarint64Slow(uint64* value); + int ReadVarintSizeAsIntSlow(); + bool ReadLittleEndian32Fallback(uint32* value); + bool ReadLittleEndian64Fallback(uint64* value); + + // Fallback/slow methods for reading tags. These do not update last_tag_, + // but will set legitimate_message_end_ if we are at the end of the input + // stream. + uint32 ReadTagFallback(uint32 first_byte_or_zero); + uint32 ReadTagSlow(); + bool ReadStringFallback(std::string* buffer, int size); + + // Return the size of the buffer. + int BufferSize() const; + + static const int kDefaultTotalBytesLimit = INT_MAX; + + static int default_recursion_limit_; // 100 by default. + + friend class google::protobuf::ZeroCopyCodedInputStream; + friend class google::protobuf::internal::EpsCopyByteStream; +}; + +// EpsCopyOutputStream wraps a ZeroCopyOutputStream and exposes a new stream, +// which has the property you can write kSlopBytes (16 bytes) from the current +// position without bounds checks. The cursor into the stream is managed by +// the user of the class and is an explicit parameter in the methods. Careful +// use of this class, ie. keep ptr a local variable, eliminates the need to +// for the compiler to sync the ptr value between register and memory. +class PROTOBUF_EXPORT EpsCopyOutputStream { + public: + enum { kSlopBytes = 16 }; + + // Initialize from a stream. + EpsCopyOutputStream(ZeroCopyOutputStream* stream, bool deterministic, + uint8** pp) + : end_(buffer_), + stream_(stream), + is_serialization_deterministic_(deterministic) { + *pp = buffer_; + } + + // Only for array serialization. No overflow protection, end_ will be the + // pointed to the end of the array. When using this the total size is already + // known, so no need to maintain the slop region. + EpsCopyOutputStream(void* data, int size, bool deterministic) + : end_(static_cast(data) + size), + buffer_end_(nullptr), + stream_(nullptr), + is_serialization_deterministic_(deterministic) {} + + // Initialize from stream but with the first buffer already given (eager). + EpsCopyOutputStream(void* data, int size, ZeroCopyOutputStream* stream, + bool deterministic, uint8** pp) + : stream_(stream), is_serialization_deterministic_(deterministic) { + *pp = SetInitialBuffer(data, size); + } + + // Flush everything that's written into the underlying ZeroCopyOutputStream + // and trims the underlying stream to the location of ptr. + uint8* Trim(uint8* ptr); + + // After this it's guaranteed you can safely write kSlopBytes to ptr. This + // will never fail! The underlying stream can produce an error. Use HadError + // to check for errors. + PROTOBUF_MUST_USE_RESULT uint8* EnsureSpace(uint8* ptr) { + if (PROTOBUF_PREDICT_FALSE(ptr >= end_)) { + return EnsureSpaceFallback(ptr); + } + return ptr; + } + + uint8* WriteRaw(const void* data, int size, uint8* ptr) { + if (PROTOBUF_PREDICT_FALSE(end_ - ptr < size)) { + return WriteRawFallback(data, size, ptr); + } + std::memcpy(ptr, data, size); + return ptr + size; + } + // Writes the buffer specified by data, size to the stream. Possibly by + // aliasing the buffer (ie. not copying the data). The caller is responsible + // to make sure the buffer is alive for the duration of the + // ZeroCopyOutputStream. + uint8* WriteRawMaybeAliased(const void* data, int size, uint8* ptr) { + if (aliasing_enabled_) { + return WriteAliasedRaw(data, size, ptr); + } else { + return WriteRaw(data, size, ptr); + } + } + + + uint8* WriteStringMaybeAliased(uint32 num, const std::string& s, uint8* ptr) { + std::ptrdiff_t size = s.size(); + if (PROTOBUF_PREDICT_FALSE( + size >= 128 || end_ - ptr + 16 - TagSize(num << 3) - 1 < size)) { + return WriteStringMaybeAliasedOutline(num, s, ptr); + } + ptr = UnsafeVarint((num << 3) | 2, ptr); + *ptr++ = static_cast(size); + std::memcpy(ptr, s.data(), size); + return ptr + size; + } + uint8* WriteBytesMaybeAliased(uint32 num, const std::string& s, uint8* ptr) { + return WriteStringMaybeAliased(num, s, ptr); + } + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteString(uint32 num, const T& s, + uint8* ptr) { + std::ptrdiff_t size = s.size(); + if (PROTOBUF_PREDICT_FALSE( + size >= 128 || end_ - ptr + 16 - TagSize(num << 3) - 1 < size)) { + return WriteStringOutline(num, s, ptr); + } + ptr = UnsafeVarint((num << 3) | 2, ptr); + *ptr++ = static_cast(size); + std::memcpy(ptr, s.data(), size); + return ptr + size; + } + template + uint8* WriteBytes(uint32 num, const T& s, uint8* ptr) { + return WriteString(num, s, ptr); + } + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteInt32Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteUInt32Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode32); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteSInt32Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, ZigZagEncode32); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteInt64Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteUInt64Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteSInt64Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, ZigZagEncode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteEnumPacked(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteFixedPacked(int num, const T& r, + uint8* ptr) { + ptr = EnsureSpace(ptr); + constexpr auto element_size = sizeof(typename T::value_type); + auto size = r.size() * element_size; + ptr = WriteLengthDelim(num, size, ptr); + return WriteRawLittleEndian(r.data(), static_cast(size), + ptr); + } + + // Returns true if there was an underlying I/O error since this object was + // created. + bool HadError() const { return had_error_; } + + // Instructs the EpsCopyOutputStream to allow the underlying + // ZeroCopyOutputStream to hold pointers to the original structure instead of + // copying, if it supports it (i.e. output->AllowsAliasing() is true). If the + // underlying stream does not support aliasing, then enabling it has no + // affect. For now, this only affects the behavior of + // WriteRawMaybeAliased(). + // + // NOTE: It is caller's responsibility to ensure that the chunk of memory + // remains live until all of the data has been consumed from the stream. + void EnableAliasing(bool enabled); + + // See documentation on CodedOutputStream::SetSerializationDeterministic. + void SetSerializationDeterministic(bool value) { + is_serialization_deterministic_ = value; + } + + // See documentation on CodedOutputStream::IsSerializationDeterministic. + bool IsSerializationDeterministic() const { + return is_serialization_deterministic_; + } + + // The number of bytes written to the stream at position ptr, relative to the + // stream's overall position. + int64 ByteCount(uint8* ptr) const; + + + private: + uint8* end_; + uint8* buffer_end_ = buffer_; + uint8 buffer_[2 * kSlopBytes]; + ZeroCopyOutputStream* stream_; + bool had_error_ = false; + bool aliasing_enabled_ = false; // See EnableAliasing(). + bool is_serialization_deterministic_; + + uint8* EnsureSpaceFallback(uint8* ptr); + inline uint8* Next(); + int Flush(uint8* ptr); + std::ptrdiff_t GetSize(uint8* ptr) const { + GOOGLE_DCHECK(ptr <= end_ + kSlopBytes); // NOLINT + return end_ + kSlopBytes - ptr; + } + + uint8* Error() { + had_error_ = true; + // We use the patch buffer to always guarantee space to write to. + end_ = buffer_ + kSlopBytes; + return buffer_; + } + + static constexpr int TagSize(uint32 tag) { + return (tag < (1 << 7)) + ? 1 + : (tag < (1 << 14)) + ? 2 + : (tag < (1 << 21)) ? 3 : (tag < (1 << 28)) ? 4 : 5; + } + + PROTOBUF_ALWAYS_INLINE uint8* WriteTag(uint32 num, uint32 wt, uint8* ptr) { + GOOGLE_DCHECK(ptr < end_); // NOLINT + return UnsafeVarint((num << 3) | wt, ptr); + } + + PROTOBUF_ALWAYS_INLINE uint8* WriteLengthDelim(int num, uint32 size, + uint8* ptr) { + ptr = WriteTag(num, 2, ptr); + return UnsafeWriteSize(size, ptr); + } + + uint8* WriteRawFallback(const void* data, int size, uint8* ptr); + + uint8* WriteAliasedRaw(const void* data, int size, uint8* ptr); + + uint8* WriteStringMaybeAliasedOutline(uint32 num, const std::string& s, + uint8* ptr); + uint8* WriteStringOutline(uint32 num, const std::string& s, uint8* ptr); + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteVarintPacked(int num, const T& r, int size, + uint8* ptr, const E& encode) { + ptr = EnsureSpace(ptr); + ptr = WriteLengthDelim(num, size, ptr); + auto it = r.data(); + auto end = it + r.size(); + do { + ptr = EnsureSpace(ptr); + ptr = UnsafeVarint(encode(*it++), ptr); + } while (it < end); + return ptr; + } + + static uint32 Encode32(uint32 v) { return v; } + static uint64 Encode64(uint64 v) { return v; } + static uint32 ZigZagEncode32(int32 v) { + return (static_cast(v) << 1) ^ static_cast(v >> 31); + } + static uint64 ZigZagEncode64(int64 v) { + return (static_cast(v) << 1) ^ static_cast(v >> 63); + } + + template + PROTOBUF_ALWAYS_INLINE static uint8* UnsafeVarint(T value, uint8* ptr) { + static_assert(std::is_unsigned::value, + "Varint serialization must be unsigned"); + if (value < 0x80) { + ptr[0] = static_cast(value); + return ptr + 1; + } + ptr[0] = static_cast(value | 0x80); + value >>= 7; + if (value < 0x80) { + ptr[1] = static_cast(value); + return ptr + 2; + } + ptr++; + do { + *ptr = static_cast(value | 0x80); + value >>= 7; + ++ptr; + } while (PROTOBUF_PREDICT_FALSE(value >= 0x80)); + *ptr++ = static_cast(value); + return ptr; + } + + PROTOBUF_ALWAYS_INLINE static uint8* UnsafeWriteSize(uint32 value, + uint8* ptr) { + while (PROTOBUF_PREDICT_FALSE(value >= 0x80)) { + *ptr = static_cast(value | 0x80); + value >>= 7; + ++ptr; + } + *ptr++ = static_cast(value); + return ptr; + } + + template + uint8* WriteRawLittleEndian(const void* data, int size, uint8* ptr); +#ifndef PROTOBUF_LITTLE_ENDIAN + uint8* WriteRawLittleEndian32(const void* data, int size, uint8* ptr); + uint8* WriteRawLittleEndian64(const void* data, int size, uint8* ptr); +#endif + + // These methods are for CodedOutputStream. Ideally they should be private + // but to match current behavior of CodedOutputStream as close as possible + // we allow it some functionality. + public: + uint8* SetInitialBuffer(void* data, int size) { + auto ptr = static_cast(data); + if (size > kSlopBytes) { + end_ = ptr + size - kSlopBytes; + buffer_end_ = nullptr; + return ptr; + } else { + end_ = buffer_ + size; + buffer_end_ = ptr; + return buffer_; + } + } + + private: + // Needed by CodedOutputStream HadError. HadError needs to flush the patch + // buffers to ensure there is no error as of yet. + uint8* FlushAndResetBuffer(uint8*); + + // The following functions mimick the old CodedOutputStream behavior as close + // as possible. They flush the current state to the stream, behave as + // the old CodedOutputStream and then return to normal operation. + bool Skip(int count, uint8** pp); + bool GetDirectBufferPointer(void** data, int* size, uint8** pp); + uint8* GetDirectBufferForNBytesAndAdvance(int size, uint8** pp); + + friend class CodedOutputStream; +}; + +template <> +inline uint8* EpsCopyOutputStream::WriteRawLittleEndian<1>(const void* data, + int size, + uint8* ptr) { + return WriteRaw(data, size, ptr); +} +template <> +inline uint8* EpsCopyOutputStream::WriteRawLittleEndian<4>(const void* data, + int size, + uint8* ptr) { +#ifdef PROTOBUF_LITTLE_ENDIAN + return WriteRaw(data, size, ptr); +#else + return WriteRawLittleEndian32(data, size, ptr); +#endif +} +template <> +inline uint8* EpsCopyOutputStream::WriteRawLittleEndian<8>(const void* data, + int size, + uint8* ptr) { +#ifdef PROTOBUF_LITTLE_ENDIAN + return WriteRaw(data, size, ptr); +#else + return WriteRawLittleEndian64(data, size, ptr); +#endif +} + +// Class which encodes and writes binary data which is composed of varint- +// encoded integers and fixed-width pieces. Wraps a ZeroCopyOutputStream. +// Most users will not need to deal with CodedOutputStream. +// +// Most methods of CodedOutputStream which return a bool return false if an +// underlying I/O error occurs. Once such a failure occurs, the +// CodedOutputStream is broken and is no longer useful. The Write* methods do +// not return the stream status, but will invalidate the stream if an error +// occurs. The client can probe HadError() to determine the status. +// +// Note that every method of CodedOutputStream which writes some data has +// a corresponding static "ToArray" version. These versions write directly +// to the provided buffer, returning a pointer past the last written byte. +// They require that the buffer has sufficient capacity for the encoded data. +// This allows an optimization where we check if an output stream has enough +// space for an entire message before we start writing and, if there is, we +// call only the ToArray methods to avoid doing bound checks for each +// individual value. +// i.e., in the example above: +// +// CodedOutputStream* coded_output = new CodedOutputStream(raw_output); +// int magic_number = 1234; +// char text[] = "Hello world!"; +// +// int coded_size = sizeof(magic_number) + +// CodedOutputStream::VarintSize32(strlen(text)) + +// strlen(text); +// +// uint8* buffer = +// coded_output->GetDirectBufferForNBytesAndAdvance(coded_size); +// if (buffer != nullptr) { +// // The output stream has enough space in the buffer: write directly to +// // the array. +// buffer = CodedOutputStream::WriteLittleEndian32ToArray(magic_number, +// buffer); +// buffer = CodedOutputStream::WriteVarint32ToArray(strlen(text), buffer); +// buffer = CodedOutputStream::WriteRawToArray(text, strlen(text), buffer); +// } else { +// // Make bound-checked writes, which will ask the underlying stream for +// // more space as needed. +// coded_output->WriteLittleEndian32(magic_number); +// coded_output->WriteVarint32(strlen(text)); +// coded_output->WriteRaw(text, strlen(text)); +// } +// +// delete coded_output; +class PROTOBUF_EXPORT CodedOutputStream { + public: + // Create an CodedOutputStream that writes to the given ZeroCopyOutputStream. + explicit CodedOutputStream(ZeroCopyOutputStream* stream) + : CodedOutputStream(stream, true) {} + CodedOutputStream(ZeroCopyOutputStream* stream, bool do_eager_refresh); + + // Destroy the CodedOutputStream and position the underlying + // ZeroCopyOutputStream immediately after the last byte written. + ~CodedOutputStream(); + + // Returns true if there was an underlying I/O error since this object was + // created. On should call Trim before this function in order to catch all + // errors. + bool HadError() { + cur_ = impl_.FlushAndResetBuffer(cur_); + GOOGLE_DCHECK(cur_); + return impl_.HadError(); + } + + // Trims any unused space in the underlying buffer so that its size matches + // the number of bytes written by this stream. The underlying buffer will + // automatically be trimmed when this stream is destroyed; this call is only + // necessary if the underlying buffer is accessed *before* the stream is + // destroyed. + void Trim() { cur_ = impl_.Trim(cur_); } + + // Skips a number of bytes, leaving the bytes unmodified in the underlying + // buffer. Returns false if an underlying write error occurs. This is + // mainly useful with GetDirectBufferPointer(). + // Note of caution, the skipped bytes may contain uninitialized data. The + // caller must make sure that the skipped bytes are properly initialized, + // otherwise you might leak bytes from your heap. + bool Skip(int count) { return impl_.Skip(count, &cur_); } + + // Sets *data to point directly at the unwritten part of the + // CodedOutputStream's underlying buffer, and *size to the size of that + // buffer, but does not advance the stream's current position. This will + // always either produce a non-empty buffer or return false. If the caller + // writes any data to this buffer, it should then call Skip() to skip over + // the consumed bytes. This may be useful for implementing external fast + // serialization routines for types of data not covered by the + // CodedOutputStream interface. + bool GetDirectBufferPointer(void** data, int* size) { + return impl_.GetDirectBufferPointer(data, size, &cur_); + } + + // If there are at least "size" bytes available in the current buffer, + // returns a pointer directly into the buffer and advances over these bytes. + // The caller may then write directly into this buffer (e.g. using the + // *ToArray static methods) rather than go through CodedOutputStream. If + // there are not enough bytes available, returns NULL. The return pointer is + // invalidated as soon as any other non-const method of CodedOutputStream + // is called. + inline uint8* GetDirectBufferForNBytesAndAdvance(int size) { + return impl_.GetDirectBufferForNBytesAndAdvance(size, &cur_); + } + + // Write raw bytes, copying them from the given buffer. + void WriteRaw(const void* buffer, int size) { + cur_ = impl_.WriteRaw(buffer, size, cur_); + } + // Like WriteRaw() but will try to write aliased data if aliasing is + // turned on. + void WriteRawMaybeAliased(const void* data, int size); + // Like WriteRaw() but writing directly to the target array. + // This is _not_ inlined, as the compiler often optimizes memcpy into inline + // copy loops. Since this gets called by every field with string or bytes + // type, inlining may lead to a significant amount of code bloat, with only a + // minor performance gain. + static uint8* WriteRawToArray(const void* buffer, int size, uint8* target); + + // Equivalent to WriteRaw(str.data(), str.size()). + void WriteString(const std::string& str); + // Like WriteString() but writing directly to the target array. + static uint8* WriteStringToArray(const std::string& str, uint8* target); + // Write the varint-encoded size of str followed by str. + static uint8* WriteStringWithSizeToArray(const std::string& str, + uint8* target); + + + // Write a 32-bit little-endian integer. + void WriteLittleEndian32(uint32 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteLittleEndian32ToArray(value, Cur())); + } + // Like WriteLittleEndian32() but writing directly to the target array. + static uint8* WriteLittleEndian32ToArray(uint32 value, uint8* target); + // Write a 64-bit little-endian integer. + void WriteLittleEndian64(uint64 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteLittleEndian64ToArray(value, Cur())); + } + // Like WriteLittleEndian64() but writing directly to the target array. + static uint8* WriteLittleEndian64ToArray(uint64 value, uint8* target); + + // Write an unsigned integer with Varint encoding. Writing a 32-bit value + // is equivalent to casting it to uint64 and writing it as a 64-bit value, + // but may be more efficient. + void WriteVarint32(uint32 value); + // Like WriteVarint32() but writing directly to the target array. + static uint8* WriteVarint32ToArray(uint32 value, uint8* target); + // Write an unsigned integer with Varint encoding. + void WriteVarint64(uint64 value); + // Like WriteVarint64() but writing directly to the target array. + static uint8* WriteVarint64ToArray(uint64 value, uint8* target); + + // Equivalent to WriteVarint32() except when the value is negative, + // in which case it must be sign-extended to a full 10 bytes. + void WriteVarint32SignExtended(int32 value); + // Like WriteVarint32SignExtended() but writing directly to the target array. + static uint8* WriteVarint32SignExtendedToArray(int32 value, uint8* target); + + // This is identical to WriteVarint32(), but optimized for writing tags. + // In particular, if the input is a compile-time constant, this method + // compiles down to a couple instructions. + // Always inline because otherwise the aformentioned optimization can't work, + // but GCC by default doesn't want to inline this. + void WriteTag(uint32 value); + // Like WriteTag() but writing directly to the target array. + PROTOBUF_ALWAYS_INLINE + static uint8* WriteTagToArray(uint32 value, uint8* target); + + // Returns the number of bytes needed to encode the given value as a varint. + static size_t VarintSize32(uint32 value); + // Returns the number of bytes needed to encode the given value as a varint. + static size_t VarintSize64(uint64 value); + + // If negative, 10 bytes. Otherwise, same as VarintSize32(). + static size_t VarintSize32SignExtended(int32 value); + + // Compile-time equivalent of VarintSize32(). + template + struct StaticVarintSize32 { + static const size_t value = + (Value < (1 << 7)) + ? 1 + : (Value < (1 << 14)) + ? 2 + : (Value < (1 << 21)) ? 3 : (Value < (1 << 28)) ? 4 : 5; + }; + + // Returns the total number of bytes written since this object was created. + int ByteCount() const { + return static_cast(impl_.ByteCount(cur_) - start_count_); + } + + // Instructs the CodedOutputStream to allow the underlying + // ZeroCopyOutputStream to hold pointers to the original structure instead of + // copying, if it supports it (i.e. output->AllowsAliasing() is true). If the + // underlying stream does not support aliasing, then enabling it has no + // affect. For now, this only affects the behavior of + // WriteRawMaybeAliased(). + // + // NOTE: It is caller's responsibility to ensure that the chunk of memory + // remains live until all of the data has been consumed from the stream. + void EnableAliasing(bool enabled) { impl_.EnableAliasing(enabled); } + + // Indicate to the serializer whether the user wants derministic + // serialization. The default when this is not called comes from the global + // default, controlled by SetDefaultSerializationDeterministic. + // + // What deterministic serialization means is entirely up to the driver of the + // serialization process (i.e. the caller of methods like WriteVarint32). In + // the case of serializing a proto buffer message using one of the methods of + // MessageLite, this means that for a given binary equal messages will always + // be serialized to the same bytes. This implies: + // + // * Repeated serialization of a message will return the same bytes. + // + // * Different processes running the same binary (including on different + // machines) will serialize equal messages to the same bytes. + // + // Note that this is *not* canonical across languages. It is also unstable + // across different builds with intervening message definition changes, due to + // unknown fields. Users who need canonical serialization (e.g. persistent + // storage in a canonical form, fingerprinting) should define their own + // canonicalization specification and implement the serializer using + // reflection APIs rather than relying on this API. + void SetSerializationDeterministic(bool value) { + impl_.SetSerializationDeterministic(value); + } + + // Return whether the user wants deterministic serialization. See above. + bool IsSerializationDeterministic() const { + return impl_.IsSerializationDeterministic(); + } + + static bool IsDefaultSerializationDeterministic() { + return default_serialization_deterministic_.load( + std::memory_order_relaxed) != 0; + } + + template + void Serialize(const Func& func); + + uint8* Cur() const { return cur_; } + void SetCur(uint8* ptr) { cur_ = ptr; } + EpsCopyOutputStream* EpsCopy() { return &impl_; } + + private: + EpsCopyOutputStream impl_; + uint8* cur_; + int64 start_count_; + static std::atomic default_serialization_deterministic_; + + // See above. Other projects may use "friend" to allow them to call this. + // After SetDefaultSerializationDeterministic() completes, all protocol + // buffer serializations will be deterministic by default. Thread safe. + // However, the meaning of "after" is subtle here: to be safe, each thread + // that wants deterministic serialization by default needs to call + // SetDefaultSerializationDeterministic() or ensure on its own that another + // thread has done so. + friend void internal::MapTestForceDeterministic(); + static void SetDefaultSerializationDeterministic() { + default_serialization_deterministic_.store(true, std::memory_order_relaxed); + } + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CodedOutputStream); +}; + +// inline methods ==================================================== +// The vast majority of varints are only one byte. These inline +// methods optimize for that case. + +inline bool CodedInputStream::ReadVarint32(uint32* value) { + uint32 v = 0; + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + v = *buffer_; + if (v < 0x80) { + *value = v; + Advance(1); + return true; + } + } + int64 result = ReadVarint32Fallback(v); + *value = static_cast(result); + return result >= 0; +} + +inline bool CodedInputStream::ReadVarint64(uint64* value) { + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_) && *buffer_ < 0x80) { + *value = *buffer_; + Advance(1); + return true; + } + std::pair p = ReadVarint64Fallback(); + *value = p.first; + return p.second; +} + +inline bool CodedInputStream::ReadVarintSizeAsInt(int* value) { + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + int v = *buffer_; + if (v < 0x80) { + *value = v; + Advance(1); + return true; + } + } + *value = ReadVarintSizeAsIntFallback(); + return *value >= 0; +} + +// static +inline const uint8* CodedInputStream::ReadLittleEndian32FromArray( + const uint8* buffer, uint32* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(value, buffer, sizeof(*value)); + return buffer + sizeof(*value); +#else + *value = (static_cast(buffer[0])) | + (static_cast(buffer[1]) << 8) | + (static_cast(buffer[2]) << 16) | + (static_cast(buffer[3]) << 24); + return buffer + sizeof(*value); +#endif +} +// static +inline const uint8* CodedInputStream::ReadLittleEndian64FromArray( + const uint8* buffer, uint64* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(value, buffer, sizeof(*value)); + return buffer + sizeof(*value); +#else + uint32 part0 = (static_cast(buffer[0])) | + (static_cast(buffer[1]) << 8) | + (static_cast(buffer[2]) << 16) | + (static_cast(buffer[3]) << 24); + uint32 part1 = (static_cast(buffer[4])) | + (static_cast(buffer[5]) << 8) | + (static_cast(buffer[6]) << 16) | + (static_cast(buffer[7]) << 24); + *value = static_cast(part0) | (static_cast(part1) << 32); + return buffer + sizeof(*value); +#endif +} + +inline bool CodedInputStream::ReadLittleEndian32(uint32* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + if (PROTOBUF_PREDICT_TRUE(BufferSize() >= static_cast(sizeof(*value)))) { + buffer_ = ReadLittleEndian32FromArray(buffer_, value); + return true; + } else { + return ReadLittleEndian32Fallback(value); + } +#else + return ReadLittleEndian32Fallback(value); +#endif +} + +inline bool CodedInputStream::ReadLittleEndian64(uint64* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + if (PROTOBUF_PREDICT_TRUE(BufferSize() >= static_cast(sizeof(*value)))) { + buffer_ = ReadLittleEndian64FromArray(buffer_, value); + return true; + } else { + return ReadLittleEndian64Fallback(value); + } +#else + return ReadLittleEndian64Fallback(value); +#endif +} + +inline uint32 CodedInputStream::ReadTagNoLastTag() { + uint32 v = 0; + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + v = *buffer_; + if (v < 0x80) { + Advance(1); + return v; + } + } + v = ReadTagFallback(v); + return v; +} + +inline std::pair CodedInputStream::ReadTagWithCutoffNoLastTag( + uint32 cutoff) { + // In performance-sensitive code we can expect cutoff to be a compile-time + // constant, and things like "cutoff >= kMax1ByteVarint" to be evaluated at + // compile time. + uint32 first_byte_or_zero = 0; + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + // Hot case: buffer_ non_empty, buffer_[0] in [1, 128). + // TODO(gpike): Is it worth rearranging this? E.g., if the number of fields + // is large enough then is it better to check for the two-byte case first? + first_byte_or_zero = buffer_[0]; + if (static_cast(buffer_[0]) > 0) { + const uint32 kMax1ByteVarint = 0x7f; + uint32 tag = buffer_[0]; + Advance(1); + return std::make_pair(tag, cutoff >= kMax1ByteVarint || tag <= cutoff); + } + // Other hot case: cutoff >= 0x80, buffer_ has at least two bytes available, + // and tag is two bytes. The latter is tested by bitwise-and-not of the + // first byte and the second byte. + if (cutoff >= 0x80 && PROTOBUF_PREDICT_TRUE(buffer_ + 1 < buffer_end_) && + PROTOBUF_PREDICT_TRUE((buffer_[0] & ~buffer_[1]) >= 0x80)) { + const uint32 kMax2ByteVarint = (0x7f << 7) + 0x7f; + uint32 tag = (1u << 7) * buffer_[1] + (buffer_[0] - 0x80); + Advance(2); + // It might make sense to test for tag == 0 now, but it is so rare that + // that we don't bother. A varint-encoded 0 should be one byte unless + // the encoder lost its mind. The second part of the return value of + // this function is allowed to be either true or false if the tag is 0, + // so we don't have to check for tag == 0. We may need to check whether + // it exceeds cutoff. + bool at_or_below_cutoff = cutoff >= kMax2ByteVarint || tag <= cutoff; + return std::make_pair(tag, at_or_below_cutoff); + } + } + // Slow path + const uint32 tag = ReadTagFallback(first_byte_or_zero); + return std::make_pair(tag, static_cast(tag - 1) < cutoff); +} + +inline bool CodedInputStream::LastTagWas(uint32 expected) { + return last_tag_ == expected; +} + +inline bool CodedInputStream::ConsumedEntireMessage() { + return legitimate_message_end_; +} + +inline bool CodedInputStream::ExpectTag(uint32 expected) { + if (expected < (1 << 7)) { + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_) && + buffer_[0] == expected) { + Advance(1); + return true; + } else { + return false; + } + } else if (expected < (1 << 14)) { + if (PROTOBUF_PREDICT_TRUE(BufferSize() >= 2) && + buffer_[0] == static_cast(expected | 0x80) && + buffer_[1] == static_cast(expected >> 7)) { + Advance(2); + return true; + } else { + return false; + } + } else { + // Don't bother optimizing for larger values. + return false; + } +} + +inline const uint8* CodedInputStream::ExpectTagFromArray(const uint8* buffer, + uint32 expected) { + if (expected < (1 << 7)) { + if (buffer[0] == expected) { + return buffer + 1; + } + } else if (expected < (1 << 14)) { + if (buffer[0] == static_cast(expected | 0x80) && + buffer[1] == static_cast(expected >> 7)) { + return buffer + 2; + } + } + return nullptr; +} + +inline void CodedInputStream::GetDirectBufferPointerInline(const void** data, + int* size) { + *data = buffer_; + *size = static_cast(buffer_end_ - buffer_); +} + +inline bool CodedInputStream::ExpectAtEnd() { + // If we are at a limit we know no more bytes can be read. Otherwise, it's + // hard to say without calling Refresh(), and we'd rather not do that. + + if (buffer_ == buffer_end_ && ((buffer_size_after_limit_ != 0) || + (total_bytes_read_ == current_limit_))) { + last_tag_ = 0; // Pretend we called ReadTag()... + legitimate_message_end_ = true; // ... and it hit EOF. + return true; + } else { + return false; + } +} + +inline int CodedInputStream::CurrentPosition() const { + return total_bytes_read_ - (BufferSize() + buffer_size_after_limit_); +} + +inline void CodedInputStream::Advance(int amount) { buffer_ += amount; } + +inline void CodedInputStream::SetRecursionLimit(int limit) { + recursion_budget_ += limit - recursion_limit_; + recursion_limit_ = limit; +} + +inline bool CodedInputStream::IncrementRecursionDepth() { + --recursion_budget_; + return recursion_budget_ >= 0; +} + +inline void CodedInputStream::DecrementRecursionDepth() { + if (recursion_budget_ < recursion_limit_) ++recursion_budget_; +} + +inline void CodedInputStream::UnsafeDecrementRecursionDepth() { + assert(recursion_budget_ < recursion_limit_); + ++recursion_budget_; +} + +inline void CodedInputStream::SetExtensionRegistry(const DescriptorPool* pool, + MessageFactory* factory) { + extension_pool_ = pool; + extension_factory_ = factory; +} + +inline const DescriptorPool* CodedInputStream::GetExtensionPool() { + return extension_pool_; +} + +inline MessageFactory* CodedInputStream::GetExtensionFactory() { + return extension_factory_; +} + +inline int CodedInputStream::BufferSize() const { + return static_cast(buffer_end_ - buffer_); +} + +inline CodedInputStream::CodedInputStream(ZeroCopyInputStream* input) + : buffer_(nullptr), + buffer_end_(nullptr), + input_(input), + total_bytes_read_(0), + overflow_bytes_(0), + last_tag_(0), + legitimate_message_end_(false), + aliasing_enabled_(false), + current_limit_(kint32max), + buffer_size_after_limit_(0), + total_bytes_limit_(kDefaultTotalBytesLimit), + recursion_budget_(default_recursion_limit_), + recursion_limit_(default_recursion_limit_), + extension_pool_(nullptr), + extension_factory_(nullptr) { + // Eagerly Refresh() so buffer space is immediately available. + Refresh(); +} + +inline CodedInputStream::CodedInputStream(const uint8* buffer, int size) + : buffer_(buffer), + buffer_end_(buffer + size), + input_(nullptr), + total_bytes_read_(size), + overflow_bytes_(0), + last_tag_(0), + legitimate_message_end_(false), + aliasing_enabled_(false), + current_limit_(size), + buffer_size_after_limit_(0), + total_bytes_limit_(kDefaultTotalBytesLimit), + recursion_budget_(default_recursion_limit_), + recursion_limit_(default_recursion_limit_), + extension_pool_(nullptr), + extension_factory_(nullptr) { + // Note that setting current_limit_ == size is important to prevent some + // code paths from trying to access input_ and segfaulting. +} + +inline bool CodedInputStream::IsFlat() const { return input_ == nullptr; } + +inline bool CodedInputStream::Skip(int count) { + if (count < 0) return false; // security: count is often user-supplied + + const int original_buffer_size = BufferSize(); + + if (count <= original_buffer_size) { + // Just skipping within the current buffer. Easy. + Advance(count); + return true; + } + + return SkipFallback(count, original_buffer_size); +} + +inline uint8* CodedOutputStream::WriteVarint32ToArray(uint32 value, + uint8* target) { + return EpsCopyOutputStream::UnsafeVarint(value, target); +} + +inline uint8* CodedOutputStream::WriteVarint64ToArray(uint64 value, + uint8* target) { + return EpsCopyOutputStream::UnsafeVarint(value, target); +} + +inline void CodedOutputStream::WriteVarint32SignExtended(int32 value) { + WriteVarint64(static_cast(value)); +} + +inline uint8* CodedOutputStream::WriteVarint32SignExtendedToArray( + int32 value, uint8* target) { + return WriteVarint64ToArray(static_cast(value), target); +} + +inline uint8* CodedOutputStream::WriteLittleEndian32ToArray(uint32 value, + uint8* target) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(target, &value, sizeof(value)); +#else + target[0] = static_cast(value); + target[1] = static_cast(value >> 8); + target[2] = static_cast(value >> 16); + target[3] = static_cast(value >> 24); +#endif + return target + sizeof(value); +} + +inline uint8* CodedOutputStream::WriteLittleEndian64ToArray(uint64 value, + uint8* target) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(target, &value, sizeof(value)); +#else + uint32 part0 = static_cast(value); + uint32 part1 = static_cast(value >> 32); + + target[0] = static_cast(part0); + target[1] = static_cast(part0 >> 8); + target[2] = static_cast(part0 >> 16); + target[3] = static_cast(part0 >> 24); + target[4] = static_cast(part1); + target[5] = static_cast(part1 >> 8); + target[6] = static_cast(part1 >> 16); + target[7] = static_cast(part1 >> 24); +#endif + return target + sizeof(value); +} + +inline void CodedOutputStream::WriteVarint32(uint32 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteVarint32ToArray(value, Cur())); +} + +inline void CodedOutputStream::WriteVarint64(uint64 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteVarint64ToArray(value, Cur())); +} + +inline void CodedOutputStream::WriteTag(uint32 value) { WriteVarint32(value); } + +inline uint8* CodedOutputStream::WriteTagToArray(uint32 value, uint8* target) { + return WriteVarint32ToArray(value, target); +} + +inline size_t CodedOutputStream::VarintSize32(uint32 value) { + // This computes value == 0 ? 1 : floor(log2(value)) / 7 + 1 + // Use an explicit multiplication to implement the divide of + // a number in the 1..31 range. + // Explicit OR 0x1 to avoid calling Bits::Log2FloorNonZero(0), which is + // undefined. + uint32 log2value = Bits::Log2FloorNonZero(value | 0x1); + return static_cast((log2value * 9 + 73) / 64); +} + +inline size_t CodedOutputStream::VarintSize64(uint64 value) { + // This computes value == 0 ? 1 : floor(log2(value)) / 7 + 1 + // Use an explicit multiplication to implement the divide of + // a number in the 1..63 range. + // Explicit OR 0x1 to avoid calling Bits::Log2FloorNonZero(0), which is + // undefined. + uint32 log2value = Bits::Log2FloorNonZero64(value | 0x1); + return static_cast((log2value * 9 + 73) / 64); +} + +inline size_t CodedOutputStream::VarintSize32SignExtended(int32 value) { + if (value < 0) { + return 10; // TODO(kenton): Make this a symbolic constant. + } else { + return VarintSize32(static_cast(value)); + } +} + +inline void CodedOutputStream::WriteString(const std::string& str) { + WriteRaw(str.data(), static_cast(str.size())); +} + +inline void CodedOutputStream::WriteRawMaybeAliased(const void* data, + int size) { + cur_ = impl_.WriteRawMaybeAliased(data, size, cur_); +} + +inline uint8* CodedOutputStream::WriteRawToArray(const void* data, int size, + uint8* target) { + memcpy(target, data, size); + return target + size; +} + +inline uint8* CodedOutputStream::WriteStringToArray(const std::string& str, + uint8* target) { + return WriteRawToArray(str.data(), static_cast(str.size()), target); +} + +} // namespace io +} // namespace protobuf +} // namespace google + +#if defined(_MSC_VER) && _MSC_VER >= 1300 && !defined(__INTEL_COMPILER) +#pragma runtime_checks("c", restore) +#endif // _MSC_VER && !defined(__INTEL_COMPILER) + +#include + +#endif // GOOGLE_PROTOBUF_IO_CODED_STREAM_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/gzip_stream.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/gzip_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..cb0dac875a0720d0143d23bb2163f3f249cc5594 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/gzip_stream.h @@ -0,0 +1,207 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: brianolson@google.com (Brian Olson) +// +// This file contains the definition for classes GzipInputStream and +// GzipOutputStream. +// +// GzipInputStream decompresses data from an underlying +// ZeroCopyInputStream and provides the decompressed data as a +// ZeroCopyInputStream. +// +// GzipOutputStream is an ZeroCopyOutputStream that compresses data to +// an underlying ZeroCopyOutputStream. + +#ifndef GOOGLE_PROTOBUF_IO_GZIP_STREAM_H__ +#define GOOGLE_PROTOBUF_IO_GZIP_STREAM_H__ + + +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace io { + +// A ZeroCopyInputStream that reads compressed data through zlib +class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { + public: + // Format key for constructor + enum Format { + // zlib will autodetect gzip header or deflate stream + AUTO = 0, + + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + // buffer_size and format may be -1 for default of 64kB and GZIP format + explicit GzipInputStream(ZeroCopyInputStream* sub_stream, + Format format = AUTO, int buffer_size = -1); + virtual ~GzipInputStream(); + + // Return last error message or NULL if no error. + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size); + void BackUp(int count); + bool Skip(int count); + int64_t ByteCount() const; + + private: + Format format_; + + ZeroCopyInputStream* sub_stream_; + + z_stream zcontext_; + int zerror_; + + void* output_buffer_; + void* output_position_; + size_t output_buffer_length_; + int64 byte_count_; + + int Inflate(int flush); + void DoNextOutput(const void** data, int* size); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GzipInputStream); +}; + +class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { + public: + // Format key for constructor + enum Format { + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + struct PROTOBUF_EXPORT Options { + // Defaults to GZIP. + Format format; + + // What size buffer to use internally. Defaults to 64kB. + int buffer_size; + + // A number between 0 and 9, where 0 is no compression and 9 is best + // compression. Defaults to Z_DEFAULT_COMPRESSION (see zlib.h). + int compression_level; + + // Defaults to Z_DEFAULT_STRATEGY. Can also be set to Z_FILTERED, + // Z_HUFFMAN_ONLY, or Z_RLE. See the documentation for deflateInit2 in + // zlib.h for definitions of these constants. + int compression_strategy; + + Options(); // Initializes with default values. + }; + + // Create a GzipOutputStream with default options. + explicit GzipOutputStream(ZeroCopyOutputStream* sub_stream); + + // Create a GzipOutputStream with the given options. + GzipOutputStream(ZeroCopyOutputStream* sub_stream, const Options& options); + + virtual ~GzipOutputStream(); + + // Return last error message or NULL if no error. + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // Flushes data written so far to zipped data in the underlying stream. + // It is the caller's responsibility to flush the underlying stream if + // necessary. + // Compression may be less efficient stopping and starting around flushes. + // Returns true if no error. + // + // Please ensure that block size is > 6. Here is an excerpt from the zlib + // doc that explains why: + // + // In the case of a Z_FULL_FLUSH or Z_SYNC_FLUSH, make sure that avail_out + // is greater than six to avoid repeated flush markers due to + // avail_out == 0 on return. + bool Flush(); + + // Writes out all data and closes the gzip stream. + // It is the caller's responsibility to close the underlying stream if + // necessary. + // Returns true if no error. + bool Close(); + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size); + void BackUp(int count); + int64_t ByteCount() const; + + private: + ZeroCopyOutputStream* sub_stream_; + // Result from calling Next() on sub_stream_ + void* sub_data_; + int sub_data_size_; + + z_stream zcontext_; + int zerror_; + void* input_buffer_; + size_t input_buffer_length_; + + // Shared constructor code. + void Init(ZeroCopyOutputStream* sub_stream, const Options& options); + + // Do some compression. + // Takes zlib flush mode. + // Returns zlib error code. + int Deflate(int flush); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GzipOutputStream); +}; + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_GZIP_STREAM_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/io_win32.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/io_win32.h new file mode 100644 index 0000000000000000000000000000000000000000..bbbae7e4f95770749524b88618561c0759e3aa3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/io_win32.h @@ -0,0 +1,144 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: laszlocsomor@google.com (Laszlo Csomor) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. + +// This file contains the declarations for Windows implementations of +// commonly used POSIX functions such as open(2) and access(2), as well +// as macro definitions for flags of these functions. +// +// By including this file you'll redefine open/access/etc. to +// ::google::protobuf::io::win32::{open/access/etc.}. +// Make sure you don't include a header that attempts to redeclare or +// redefine these functions, that'll lead to confusing compilation +// errors. It's best to #include this file as the last one to ensure that. +// +// This file is only used on Windows, it's empty on other platforms. + +#ifndef GOOGLE_PROTOBUF_IO_IO_WIN32_H__ +#define GOOGLE_PROTOBUF_IO_IO_WIN32_H__ + +#if defined(_WIN32) + +#include +#include + +#include +#include + +// Compilers on Windows other than MSVC (e.g. Cygwin, MinGW32) define the +// following functions already, except for mkdir. +namespace google { +namespace protobuf { +namespace io { +namespace win32 { + +PROTOBUF_EXPORT FILE* fopen(const char* path, const char* mode); +PROTOBUF_EXPORT int access(const char* path, int mode); +PROTOBUF_EXPORT int chdir(const char* path); +PROTOBUF_EXPORT int close(int fd); +PROTOBUF_EXPORT int dup(int fd); +PROTOBUF_EXPORT int dup2(int fd1, int fd2); +PROTOBUF_EXPORT int mkdir(const char* path, int _mode); +PROTOBUF_EXPORT int open(const char* path, int flags, int mode = 0); +PROTOBUF_EXPORT int read(int fd, void* buffer, size_t size); +PROTOBUF_EXPORT int setmode(int fd, int mode); +PROTOBUF_EXPORT int stat(const char* path, struct _stat* buffer); +PROTOBUF_EXPORT int write(int fd, const void* buffer, size_t size); +PROTOBUF_EXPORT std::wstring testonly_utf8_to_winpath(const char* path); + +enum class ExpandWildcardsResult { + kSuccess = 0, + kErrorNoMatchingFile = 1, + kErrorInputPathConversion = 2, + kErrorOutputPathConversion = 3, +}; + +// Expand wildcards in a path pattern, feed the result to a consumer function. +// +// `path` must be a valid, Windows-style path. It may be absolute, or relative +// to the current working directory, and it may contain wildcards ("*" and "?") +// in the last path segment. This function passes all matching file names to +// `consume`. The resulting paths may not be absolute nor normalized. +// +// The function returns a value from `ExpandWildcardsResult`. +PROTOBUF_EXPORT ExpandWildcardsResult ExpandWildcards( + const std::string& path, std::function consume); + +namespace strings { + +// Convert from UTF-16 to Active-Code-Page-encoded or to UTF-8-encoded text. +PROTOBUF_EXPORT bool wcs_to_mbs(const wchar_t* s, std::string* out, + bool outUtf8); + +// Convert from Active-Code-Page-encoded or UTF-8-encoded text to UTF-16. +PROTOBUF_EXPORT bool mbs_to_wcs(const char* s, std::wstring* out, bool inUtf8); + +// Convert from UTF-8-encoded text to UTF-16. +PROTOBUF_EXPORT bool utf8_to_wcs(const char* input, std::wstring* out); + +// Convert from UTF-16-encoded text to UTF-8. +PROTOBUF_EXPORT bool wcs_to_utf8(const wchar_t* input, std::string* out); + +} // namespace strings + +} // namespace win32 +} // namespace io +} // namespace protobuf +} // namespace google + +#ifndef W_OK +#define W_OK 02 // not defined by MSVC for whatever reason +#endif + +#ifndef F_OK +#define F_OK 00 // not defined by MSVC for whatever reason +#endif + +#ifndef STDIN_FILENO +#define STDIN_FILENO 0 +#endif + +#ifndef STDOUT_FILENO +#define STDOUT_FILENO 1 +#endif + +#include + +#endif // defined(_WIN32) + +#endif // GOOGLE_PROTOBUF_IO_IO_WIN32_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/printer.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/printer.h new file mode 100644 index 0000000000000000000000000000000000000000..ad6985d6f019a7508ae929148a30b5b1b1d2268a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/printer.h @@ -0,0 +1,390 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Utility class for writing text to a ZeroCopyOutputStream. + +#ifndef GOOGLE_PROTOBUF_IO_PRINTER_H__ +#define GOOGLE_PROTOBUF_IO_PRINTER_H__ + + +#include +#include +#include + +#include +#include + +namespace google { +namespace protobuf { +namespace io { + +class ZeroCopyOutputStream; // zero_copy_stream.h + +// Records annotations about a Printer's output. +class PROTOBUF_EXPORT AnnotationCollector { + public: + // Annotation is a offset range and a payload pair. + typedef std::pair, std::string> Annotation; + + // Records that the bytes in file_path beginning with begin_offset and ending + // before end_offset are associated with the SourceCodeInfo-style path. + virtual void AddAnnotation(size_t begin_offset, size_t end_offset, + const std::string& file_path, + const std::vector& path) = 0; + + // TODO(gerbens) I don't see why we need virtuals here. Just a vector of + // range, payload pairs stored in a context should suffice. + virtual void AddAnnotationNew(Annotation& a) {} + + virtual ~AnnotationCollector() {} +}; + +// Records annotations about a Printer's output to the given protocol buffer, +// assuming that the buffer has an ::Annotation message exposing path, +// source_file, begin and end fields. +template +class AnnotationProtoCollector : public AnnotationCollector { + public: + // annotation_proto is the protocol buffer to which new Annotations should be + // added. It is not owned by the AnnotationProtoCollector. + explicit AnnotationProtoCollector(AnnotationProto* annotation_proto) + : annotation_proto_(annotation_proto) {} + + // Override for AnnotationCollector::AddAnnotation. + virtual void AddAnnotation(size_t begin_offset, size_t end_offset, + const std::string& file_path, + const std::vector& path) { + typename AnnotationProto::Annotation* annotation = + annotation_proto_->add_annotation(); + for (int i = 0; i < path.size(); ++i) { + annotation->add_path(path[i]); + } + annotation->set_source_file(file_path); + annotation->set_begin(begin_offset); + annotation->set_end(end_offset); + } + // Override for AnnotationCollector::AddAnnotation. + virtual void AddAnnotationNew(Annotation& a) { + auto* annotation = annotation_proto_->add_annotation(); + annotation->ParseFromString(a.second); + annotation->set_begin(a.first.first); + annotation->set_end(a.first.second); + } + + private: + // The protocol buffer to which new annotations should be added. + AnnotationProto* const annotation_proto_; +}; + +// This simple utility class assists in code generation. It basically +// allows the caller to define a set of variables and then output some +// text with variable substitutions. Example usage: +// +// Printer printer(output, '$'); +// map vars; +// vars["name"] = "Bob"; +// printer.Print(vars, "My name is $name$."); +// +// The above writes "My name is Bob." to the output stream. +// +// Printer aggressively enforces correct usage, crashing (with assert failures) +// in the case of undefined variables in debug builds. This helps greatly in +// debugging code which uses it. +// +// If a Printer is constructed with an AnnotationCollector, it will provide it +// with annotations that connect the Printer's output to paths that can identify +// various descriptors. In the above example, if person_ is a descriptor that +// identifies Bob, we can associate the output string "My name is Bob." with +// a source path pointing to that descriptor with: +// +// printer.Annotate("name", person_); +// +// The AnnotationCollector will be sent an annotation linking the output range +// covering "Bob" to the logical path provided by person_. Tools may use +// this association to (for example) link "Bob" in the output back to the +// source file that defined the person_ descriptor identifying Bob. +// +// Annotate can only examine variables substituted during the last call to +// Print. It is invalid to refer to a variable that was used multiple times +// in a single Print call. +// +// In full generality, one may specify a range of output text using a beginning +// substitution variable and an ending variable. The resulting annotation will +// span from the first character of the substituted value for the beginning +// variable to the last character of the substituted value for the ending +// variable. For example, the Annotate call above is equivalent to this one: +// +// printer.Annotate("name", "name", person_); +// +// This is useful if multiple variables combine to form a single span of output +// that should be annotated with the same source path. For example: +// +// Printer printer(output, '$'); +// map vars; +// vars["first"] = "Alice"; +// vars["last"] = "Smith"; +// printer.Print(vars, "My name is $first$ $last$."); +// printer.Annotate("first", "last", person_); +// +// This code would associate the span covering "Alice Smith" in the output with +// the person_ descriptor. +// +// Note that the beginning variable must come before (or overlap with, in the +// case of zero-sized substitution values) the ending variable. +// +// It is also sometimes useful to use variables with zero-sized values as +// markers. This avoids issues with multiple references to the same variable +// and also allows annotation ranges to span literal text from the Print +// templates: +// +// Printer printer(output, '$'); +// map vars; +// vars["foo"] = "bar"; +// vars["function"] = "call"; +// vars["mark"] = ""; +// printer.Print(vars, "$function$($foo$,$foo$)$mark$"); +// printer.Annotate("function", "mark", call_); +// +// This code associates the span covering "call(bar,bar)" in the output with the +// call_ descriptor. + +class PROTOBUF_EXPORT Printer { + public: + // Create a printer that writes text to the given output stream. Use the + // given character as the delimiter for variables. + Printer(ZeroCopyOutputStream* output, char variable_delimiter); + + // Create a printer that writes text to the given output stream. Use the + // given character as the delimiter for variables. If annotation_collector + // is not null, Printer will provide it with annotations about code written + // to the stream. annotation_collector is not owned by Printer. + Printer(ZeroCopyOutputStream* output, char variable_delimiter, + AnnotationCollector* annotation_collector); + + ~Printer(); + + // Link a substitution variable emitted by the last call to Print to the + // object described by descriptor. + template + void Annotate(const char* varname, const SomeDescriptor* descriptor) { + Annotate(varname, varname, descriptor); + } + + // Link the output range defined by the substitution variables as emitted by + // the last call to Print to the object described by descriptor. The range + // begins at begin_varname's value and ends after the last character of the + // value substituted for end_varname. + template + void Annotate(const char* begin_varname, const char* end_varname, + const SomeDescriptor* descriptor) { + if (annotation_collector_ == NULL) { + // Annotations aren't turned on for this Printer, so don't pay the cost + // of building the location path. + return; + } + std::vector path; + descriptor->GetLocationPath(&path); + Annotate(begin_varname, end_varname, descriptor->file()->name(), path); + } + + // Link a substitution variable emitted by the last call to Print to the file + // with path file_name. + void Annotate(const char* varname, const std::string& file_name) { + Annotate(varname, varname, file_name); + } + + // Link the output range defined by the substitution variables as emitted by + // the last call to Print to the file with path file_name. The range begins + // at begin_varname's value and ends after the last character of the value + // substituted for end_varname. + void Annotate(const char* begin_varname, const char* end_varname, + const std::string& file_name) { + if (annotation_collector_ == NULL) { + // Annotations aren't turned on for this Printer. + return; + } + std::vector empty_path; + Annotate(begin_varname, end_varname, file_name, empty_path); + } + + // Print some text after applying variable substitutions. If a particular + // variable in the text is not defined, this will crash. Variables to be + // substituted are identified by their names surrounded by delimiter + // characters (as given to the constructor). The variable bindings are + // defined by the given map. + void Print(const std::map& variables, + const char* text); + + // Like the first Print(), except the substitutions are given as parameters. + template + void Print(const char* text, const Args&... args) { + std::map vars; + PrintInternal(&vars, text, args...); + } + + // Indent text by two spaces. After calling Indent(), two spaces will be + // inserted at the beginning of each line of text. Indent() may be called + // multiple times to produce deeper indents. + void Indent(); + + // Reduces the current indent level by two spaces, or crashes if the indent + // level is zero. + void Outdent(); + + // Write a string to the output buffer. + // This method does not look for newlines to add indentation. + void PrintRaw(const std::string& data); + + // Write a zero-delimited string to output buffer. + // This method does not look for newlines to add indentation. + void PrintRaw(const char* data); + + // Write some bytes to the output buffer. + // This method does not look for newlines to add indentation. + void WriteRaw(const char* data, int size); + + // FormatInternal is a helper function not meant to use directly, use + // compiler::cpp::Formatter instead. This function is meant to support + // formatting text using named variables (eq. "$foo$) from a lookup map (vars) + // and variables directly supplied by arguments (eq "$1$" meaning first + // argument which is the zero index element of args). + void FormatInternal(const std::vector& args, + const std::map& vars, + const char* format); + + // True if any write to the underlying stream failed. (We don't just + // crash in this case because this is an I/O failure, not a programming + // error.) + bool failed() const { return failed_; } + + private: + // Link the output range defined by the substitution variables as emitted by + // the last call to Print to the object found at the SourceCodeInfo-style path + // in a file with path file_path. The range begins at the start of + // begin_varname's value and ends after the last character of the value + // substituted for end_varname. Note that begin_varname and end_varname + // may refer to the same variable. + void Annotate(const char* begin_varname, const char* end_varname, + const std::string& file_path, const std::vector& path); + + // Base case + void PrintInternal(std::map* vars, + const char* text) { + Print(*vars, text); + } + + template + void PrintInternal(std::map* vars, const char* text, + const char* key, const std::string& value, + const Args&... args) { + (*vars)[key] = value; + PrintInternal(vars, text, args...); + } + + // Copy size worth of bytes from data to buffer_. + void CopyToBuffer(const char* data, int size); + + void push_back(char c) { + if (failed_) return; + if (buffer_size_ == 0) { + if (!Next()) return; + } + *buffer_++ = c; + buffer_size_--; + offset_++; + } + + bool Next(); + + inline void IndentIfAtStart(); + const char* WriteVariable( + const std::vector& args, + const std::map& vars, const char* format, + int* arg_index, + std::vector* annotations); + + const char variable_delimiter_; + + ZeroCopyOutputStream* const output_; + char* buffer_; + int buffer_size_; + // The current position, in bytes, in the output stream. This is equivalent + // to the total number of bytes that have been written so far. This value is + // used to calculate annotation ranges in the substitutions_ map below. + size_t offset_; + + std::string indent_; + bool at_start_of_line_; + bool failed_; + + // A map from variable name to [start, end) offsets in the output buffer. + // These refer to the offsets used for a variable after the last call to + // Print. If a variable was used more than once, the entry used in + // this map is set to a negative-length span. For singly-used variables, the + // start offset is the beginning of the substitution; the end offset is the + // last byte of the substitution plus one (such that (end - start) is the + // length of the substituted string). + std::map > substitutions_; + + // Keeps track of the keys in substitutions_ that need to be updated when + // indents are inserted. These are keys that refer to the beginning of the + // current line. + std::vector line_start_variables_; + + // Returns true and sets range to the substitution range in the output for + // varname if varname was used once in the last call to Print. If varname + // was not used, or if it was used multiple times, returns false (and + // fails a debug assertion). + bool GetSubstitutionRange(const char* varname, + std::pair* range); + + // If non-null, annotation_collector_ is used to store annotations about + // generated code. + AnnotationCollector* const annotation_collector_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Printer); +}; + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_PRINTER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/strtod.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/strtod.h new file mode 100644 index 0000000000000000000000000000000000000000..e05ba81b001b4bd4c5c0c59a175b02083b74e501 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/strtod.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A locale-independent version of strtod(), used to parse floating +// point default values in .proto files, where the decimal separator +// is always a dot. + +#ifndef GOOGLE_PROTOBUF_IO_STRTOD_H__ +#define GOOGLE_PROTOBUF_IO_STRTOD_H__ + +namespace google { +namespace protobuf { +namespace io { + +// A locale-independent version of the standard strtod(), which always +// uses a dot as the decimal separator. +double NoLocaleStrtod(const char* str, char** endptr); + +// Casts a double value to a float value. If the value is outside of the +// representable range of float, it will be converted to positive or negative +// infinity. +float SafeDoubleToFloat(double value); + +} // namespace io +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_IO_STRTOD_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/tokenizer.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/tokenizer.h new file mode 100644 index 0000000000000000000000000000000000000000..984a0597e6a916eaa809333d99e138b7c962cb40 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/tokenizer.h @@ -0,0 +1,418 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Class for parsing tokenized text from a ZeroCopyInputStream. + +#ifndef GOOGLE_PROTOBUF_IO_TOKENIZER_H__ +#define GOOGLE_PROTOBUF_IO_TOKENIZER_H__ + + +#include +#include + +#include +#include +#include + +namespace google { +namespace protobuf { +namespace io { + +class ZeroCopyInputStream; // zero_copy_stream.h + +// Defined in this file. +class ErrorCollector; +class Tokenizer; + +// By "column number", the proto compiler refers to a count of the number +// of bytes before a given byte, except that a tab character advances to +// the next multiple of 8 bytes. Note in particular that column numbers +// are zero-based, while many user interfaces use one-based column numbers. +typedef int ColumnNumber; + +// Abstract interface for an object which collects the errors that occur +// during parsing. A typical implementation might simply print the errors +// to stdout. +class PROTOBUF_EXPORT ErrorCollector { + public: + inline ErrorCollector() {} + virtual ~ErrorCollector(); + + // Indicates that there was an error in the input at the given line and + // column numbers. The numbers are zero-based, so you may want to add + // 1 to each before printing them. + virtual void AddError(int line, ColumnNumber column, + const std::string& message) = 0; + + // Indicates that there was a warning in the input at the given line and + // column numbers. The numbers are zero-based, so you may want to add + // 1 to each before printing them. + virtual void AddWarning(int line, ColumnNumber column, + const std::string& message) {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ErrorCollector); +}; + +// This class converts a stream of raw text into a stream of tokens for +// the protocol definition parser to parse. The tokens recognized are +// similar to those that make up the C language; see the TokenType enum for +// precise descriptions. Whitespace and comments are skipped. By default, +// C- and C++-style comments are recognized, but other styles can be used by +// calling set_comment_style(). +class PROTOBUF_EXPORT Tokenizer { + public: + // Construct a Tokenizer that reads and tokenizes text from the given + // input stream and writes errors to the given error_collector. + // The caller keeps ownership of input and error_collector. + Tokenizer(ZeroCopyInputStream* input, ErrorCollector* error_collector); + ~Tokenizer(); + + enum TokenType { + TYPE_START, // Next() has not yet been called. + TYPE_END, // End of input reached. "text" is empty. + + TYPE_IDENTIFIER, // A sequence of letters, digits, and underscores, not + // starting with a digit. It is an error for a number + // to be followed by an identifier with no space in + // between. + TYPE_INTEGER, // A sequence of digits representing an integer. Normally + // the digits are decimal, but a prefix of "0x" indicates + // a hex number and a leading zero indicates octal, just + // like with C numeric literals. A leading negative sign + // is NOT included in the token; it's up to the parser to + // interpret the unary minus operator on its own. + TYPE_FLOAT, // A floating point literal, with a fractional part and/or + // an exponent. Always in decimal. Again, never + // negative. + TYPE_STRING, // A quoted sequence of escaped characters. Either single + // or double quotes can be used, but they must match. + // A string literal cannot cross a line break. + TYPE_SYMBOL, // Any other printable character, like '!' or '+'. + // Symbols are always a single character, so "!+$%" is + // four tokens. + }; + + // Structure representing a token read from the token stream. + struct Token { + TokenType type; + std::string text; // The exact text of the token as it appeared in + // the input. e.g. tokens of TYPE_STRING will still + // be escaped and in quotes. + + // "line" and "column" specify the position of the first character of + // the token within the input stream. They are zero-based. + int line; + ColumnNumber column; + ColumnNumber end_column; + }; + + // Get the current token. This is updated when Next() is called. Before + // the first call to Next(), current() has type TYPE_START and no contents. + const Token& current(); + + // Return the previous token -- i.e. what current() returned before the + // previous call to Next(). + const Token& previous(); + + // Advance to the next token. Returns false if the end of the input is + // reached. + bool Next(); + + // Like Next(), but also collects comments which appear between the previous + // and next tokens. + // + // Comments which appear to be attached to the previous token are stored + // in *prev_tailing_comments. Comments which appear to be attached to the + // next token are stored in *next_leading_comments. Comments appearing in + // between which do not appear to be attached to either will be added to + // detached_comments. Any of these parameters can be NULL to simply discard + // the comments. + // + // A series of line comments appearing on consecutive lines, with no other + // tokens appearing on those lines, will be treated as a single comment. + // + // Only the comment content is returned; comment markers (e.g. //) are + // stripped out. For block comments, leading whitespace and an asterisk will + // be stripped from the beginning of each line other than the first. Newlines + // are included in the output. + // + // Examples: + // + // optional int32 foo = 1; // Comment attached to foo. + // // Comment attached to bar. + // optional int32 bar = 2; + // + // optional string baz = 3; + // // Comment attached to baz. + // // Another line attached to baz. + // + // // Comment attached to qux. + // // + // // Another line attached to qux. + // optional double qux = 4; + // + // // Detached comment. This is not attached to qux or corge + // // because there are blank lines separating it from both. + // + // optional string corge = 5; + // /* Block comment attached + // * to corge. Leading asterisks + // * will be removed. */ + // /* Block comment attached to + // * grault. */ + // optional int32 grault = 6; + bool NextWithComments(std::string* prev_trailing_comments, + std::vector* detached_comments, + std::string* next_leading_comments); + + // Parse helpers --------------------------------------------------- + + // Parses a TYPE_FLOAT token. This never fails, so long as the text actually + // comes from a TYPE_FLOAT token parsed by Tokenizer. If it doesn't, the + // result is undefined (possibly an assert failure). + static double ParseFloat(const std::string& text); + + // Parses a TYPE_STRING token. This never fails, so long as the text actually + // comes from a TYPE_STRING token parsed by Tokenizer. If it doesn't, the + // result is undefined (possibly an assert failure). + static void ParseString(const std::string& text, std::string* output); + + // Identical to ParseString, but appends to output. + static void ParseStringAppend(const std::string& text, std::string* output); + + // Parses a TYPE_INTEGER token. Returns false if the result would be + // greater than max_value. Otherwise, returns true and sets *output to the + // result. If the text is not from a Token of type TYPE_INTEGER originally + // parsed by a Tokenizer, the result is undefined (possibly an assert + // failure). + static bool ParseInteger(const std::string& text, uint64 max_value, + uint64* output); + + // Options --------------------------------------------------------- + + // Set true to allow floats to be suffixed with the letter 'f'. Tokens + // which would otherwise be integers but which have the 'f' suffix will be + // forced to be interpreted as floats. For all other purposes, the 'f' is + // ignored. + void set_allow_f_after_float(bool value) { allow_f_after_float_ = value; } + + // Valid values for set_comment_style(). + enum CommentStyle { + // Line comments begin with "//", block comments are delimited by "/*" and + // "*/". + CPP_COMMENT_STYLE, + // Line comments begin with "#". No way to write block comments. + SH_COMMENT_STYLE + }; + + // Sets the comment style. + void set_comment_style(CommentStyle style) { comment_style_ = style; } + + // Whether to require whitespace between a number and a field name. + // Default is true. Do not use this; for Google-internal cleanup only. + void set_require_space_after_number(bool require) { + require_space_after_number_ = require; + } + + // Whether to allow string literals to span multiple lines. Default is false. + // Do not use this; for Google-internal cleanup only. + void set_allow_multiline_strings(bool allow) { + allow_multiline_strings_ = allow; + } + + // External helper: validate an identifier. + static bool IsIdentifier(const std::string& text); + + // ----------------------------------------------------------------- + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Tokenizer); + + Token current_; // Returned by current(). + Token previous_; // Returned by previous(). + + ZeroCopyInputStream* input_; + ErrorCollector* error_collector_; + + char current_char_; // == buffer_[buffer_pos_], updated by NextChar(). + const char* buffer_; // Current buffer returned from input_. + int buffer_size_; // Size of buffer_. + int buffer_pos_; // Current position within the buffer. + bool read_error_; // Did we previously encounter a read error? + + // Line and column number of current_char_ within the whole input stream. + int line_; + ColumnNumber column_; + + // String to which text should be appended as we advance through it. + // Call RecordTo(&str) to start recording and StopRecording() to stop. + // E.g. StartToken() calls RecordTo(¤t_.text). record_start_ is the + // position within the current buffer where recording started. + std::string* record_target_; + int record_start_; + + // Options. + bool allow_f_after_float_; + CommentStyle comment_style_; + bool require_space_after_number_; + bool allow_multiline_strings_; + + // Since we count columns we need to interpret tabs somehow. We'll take + // the standard 8-character definition for lack of any way to do better. + // This must match the documentation of ColumnNumber. + static const int kTabWidth = 8; + + // ----------------------------------------------------------------- + // Helper methods. + + // Consume this character and advance to the next one. + void NextChar(); + + // Read a new buffer from the input. + void Refresh(); + + inline void RecordTo(std::string* target); + inline void StopRecording(); + + // Called when the current character is the first character of a new + // token (not including whitespace or comments). + inline void StartToken(); + // Called when the current character is the first character after the + // end of the last token. After this returns, current_.text will + // contain all text consumed since StartToken() was called. + inline void EndToken(); + + // Convenience method to add an error at the current line and column. + void AddError(const std::string& message) { + error_collector_->AddError(line_, column_, message); + } + + // ----------------------------------------------------------------- + // The following four methods are used to consume tokens of specific + // types. They are actually used to consume all characters *after* + // the first, since the calling function consumes the first character + // in order to decide what kind of token is being read. + + // Read and consume a string, ending when the given delimiter is + // consumed. + void ConsumeString(char delimiter); + + // Read and consume a number, returning TYPE_FLOAT or TYPE_INTEGER + // depending on what was read. This needs to know if the first + // character was a zero in order to correctly recognize hex and octal + // numbers. + // It also needs to know if the first character was a . to parse floating + // point correctly. + TokenType ConsumeNumber(bool started_with_zero, bool started_with_dot); + + // Consume the rest of a line. + void ConsumeLineComment(std::string* content); + // Consume until "*/". + void ConsumeBlockComment(std::string* content); + + enum NextCommentStatus { + // Started a line comment. + LINE_COMMENT, + + // Started a block comment. + BLOCK_COMMENT, + + // Consumed a slash, then realized it wasn't a comment. current_ has + // been filled in with a slash token. The caller should return it. + SLASH_NOT_COMMENT, + + // We do not appear to be starting a comment here. + NO_COMMENT + }; + + // If we're at the start of a new comment, consume it and return what kind + // of comment it is. + NextCommentStatus TryConsumeCommentStart(); + + // ----------------------------------------------------------------- + // These helper methods make the parsing code more readable. The + // "character classes" referred to are defined at the top of the .cc file. + // Basically it is a C++ class with one method: + // static bool InClass(char c); + // The method returns true if c is a member of this "class", like "Letter" + // or "Digit". + + // Returns true if the current character is of the given character + // class, but does not consume anything. + template + inline bool LookingAt(); + + // If the current character is in the given class, consume it and return + // true. Otherwise return false. + // e.g. TryConsumeOne() + template + inline bool TryConsumeOne(); + + // Like above, but try to consume the specific character indicated. + inline bool TryConsume(char c); + + // Consume zero or more of the given character class. + template + inline void ConsumeZeroOrMore(); + + // Consume one or more of the given character class or log the given + // error message. + // e.g. ConsumeOneOrMore("Expected digits."); + template + inline void ConsumeOneOrMore(const char* error); +}; + +// inline methods ==================================================== +inline const Tokenizer::Token& Tokenizer::current() { return current_; } + +inline const Tokenizer::Token& Tokenizer::previous() { return previous_; } + +inline void Tokenizer::ParseString(const std::string& text, + std::string* output) { + output->clear(); + ParseStringAppend(text, output); +} + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_TOKENIZER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..b310d3a56b8949247d9b7cc4fec78fd0a356c12c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream.h @@ -0,0 +1,258 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains the ZeroCopyInputStream and ZeroCopyOutputStream +// interfaces, which represent abstract I/O streams to and from which +// protocol buffers can be read and written. For a few simple +// implementations of these interfaces, see zero_copy_stream_impl.h. +// +// These interfaces are different from classic I/O streams in that they +// try to minimize the amount of data copying that needs to be done. +// To accomplish this, responsibility for allocating buffers is moved to +// the stream object, rather than being the responsibility of the caller. +// So, the stream can return a buffer which actually points directly into +// the final data structure where the bytes are to be stored, and the caller +// can interact directly with that buffer, eliminating an intermediate copy +// operation. +// +// As an example, consider the common case in which you are reading bytes +// from an array that is already in memory (or perhaps an mmap()ed file). +// With classic I/O streams, you would do something like: +// char buffer[BUFFER_SIZE]; +// input->Read(buffer, BUFFER_SIZE); +// DoSomething(buffer, BUFFER_SIZE); +// Then, the stream basically just calls memcpy() to copy the data from +// the array into your buffer. With a ZeroCopyInputStream, you would do +// this instead: +// const void* buffer; +// int size; +// input->Next(&buffer, &size); +// DoSomething(buffer, size); +// Here, no copy is performed. The input stream returns a pointer directly +// into the backing array, and the caller ends up reading directly from it. +// +// If you want to be able to read the old-fashion way, you can create +// a CodedInputStream or CodedOutputStream wrapping these objects and use +// their ReadRaw()/WriteRaw() methods. These will, of course, add a copy +// step, but Coded*Stream will handle buffering so at least it will be +// reasonably efficient. +// +// ZeroCopyInputStream example: +// // Read in a file and print its contents to stdout. +// int fd = open("myfile", O_RDONLY); +// ZeroCopyInputStream* input = new FileInputStream(fd); +// +// const void* buffer; +// int size; +// while (input->Next(&buffer, &size)) { +// cout.write(buffer, size); +// } +// +// delete input; +// close(fd); +// +// ZeroCopyOutputStream example: +// // Copy the contents of "infile" to "outfile", using plain read() for +// // "infile" but a ZeroCopyOutputStream for "outfile". +// int infd = open("infile", O_RDONLY); +// int outfd = open("outfile", O_WRONLY); +// ZeroCopyOutputStream* output = new FileOutputStream(outfd); +// +// void* buffer; +// int size; +// while (output->Next(&buffer, &size)) { +// int bytes = read(infd, buffer, size); +// if (bytes < size) { +// // Reached EOF. +// output->BackUp(size - bytes); +// break; +// } +// } +// +// delete output; +// close(infd); +// close(outfd); + +#ifndef GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_H__ +#define GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_H__ + + +#include + +#include +#include + + +namespace google { +namespace protobuf { +namespace io { + +// Defined in this file. +class ZeroCopyInputStream; +class ZeroCopyOutputStream; + +// Abstract interface similar to an input stream but designed to minimize +// copying. +class PROTOBUF_EXPORT ZeroCopyInputStream { + public: + ZeroCopyInputStream() {} + virtual ~ZeroCopyInputStream() {} + + // Obtains a chunk of data from the stream. + // + // Preconditions: + // * "size" and "data" are not NULL. + // + // Postconditions: + // * If the returned value is false, there is no more data to return or + // an error occurred. All errors are permanent. + // * Otherwise, "size" points to the actual number of bytes read and "data" + // points to a pointer to a buffer containing these bytes. + // * Ownership of this buffer remains with the stream, and the buffer + // remains valid only until some other method of the stream is called + // or the stream is destroyed. + // * It is legal for the returned buffer to have zero size, as long + // as repeatedly calling Next() eventually yields a buffer with non-zero + // size. + virtual bool Next(const void** data, int* size) = 0; + + // Backs up a number of bytes, so that the next call to Next() returns + // data again that was already returned by the last call to Next(). This + // is useful when writing procedures that are only supposed to read up + // to a certain point in the input, then return. If Next() returns a + // buffer that goes beyond what you wanted to read, you can use BackUp() + // to return to the point where you intended to finish. + // + // Preconditions: + // * The last method called must have been Next(). + // * count must be less than or equal to the size of the last buffer + // returned by Next(). + // + // Postconditions: + // * The last "count" bytes of the last buffer returned by Next() will be + // pushed back into the stream. Subsequent calls to Next() will return + // the same data again before producing new data. + virtual void BackUp(int count) = 0; + + // Skips a number of bytes. Returns false if the end of the stream is + // reached or some input error occurred. In the end-of-stream case, the + // stream is advanced to the end of the stream (so ByteCount() will return + // the total size of the stream). + virtual bool Skip(int count) = 0; + + // Returns the total number of bytes read since this object was created. + virtual int64_t ByteCount() const = 0; + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ZeroCopyInputStream); +}; + +// Abstract interface similar to an output stream but designed to minimize +// copying. +class PROTOBUF_EXPORT ZeroCopyOutputStream { + public: + ZeroCopyOutputStream() {} + virtual ~ZeroCopyOutputStream() {} + + // Obtains a buffer into which data can be written. Any data written + // into this buffer will eventually (maybe instantly, maybe later on) + // be written to the output. + // + // Preconditions: + // * "size" and "data" are not NULL. + // + // Postconditions: + // * If the returned value is false, an error occurred. All errors are + // permanent. + // * Otherwise, "size" points to the actual number of bytes in the buffer + // and "data" points to the buffer. + // * Ownership of this buffer remains with the stream, and the buffer + // remains valid only until some other method of the stream is called + // or the stream is destroyed. + // * Any data which the caller stores in this buffer will eventually be + // written to the output (unless BackUp() is called). + // * It is legal for the returned buffer to have zero size, as long + // as repeatedly calling Next() eventually yields a buffer with non-zero + // size. + virtual bool Next(void** data, int* size) = 0; + + // Backs up a number of bytes, so that the end of the last buffer returned + // by Next() is not actually written. This is needed when you finish + // writing all the data you want to write, but the last buffer was bigger + // than you needed. You don't want to write a bunch of garbage after the + // end of your data, so you use BackUp() to back up. + // + // Preconditions: + // * The last method called must have been Next(). + // * count must be less than or equal to the size of the last buffer + // returned by Next(). + // * The caller must not have written anything to the last "count" bytes + // of that buffer. + // + // Postconditions: + // * The last "count" bytes of the last buffer returned by Next() will be + // ignored. + virtual void BackUp(int count) = 0; + + // Returns the total number of bytes written since this object was created. + virtual int64_t ByteCount() const = 0; + + // Write a given chunk of data to the output. Some output streams may + // implement this in a way that avoids copying. Check AllowsAliasing() before + // calling WriteAliasedRaw(). It will GOOGLE_CHECK fail if WriteAliasedRaw() is + // called on a stream that does not allow aliasing. + // + // NOTE: It is caller's responsibility to ensure that the chunk of memory + // remains live until all of the data has been consumed from the stream. + virtual bool WriteAliasedRaw(const void* data, int size); + virtual bool AllowsAliasing() const { return false; } + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ZeroCopyOutputStream); +}; + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..b18d5451c90c08ac6e80d62534510c52051a1067 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl.h @@ -0,0 +1,343 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains common implementations of the interfaces defined in +// zero_copy_stream.h which are only included in the full (non-lite) +// protobuf library. These implementations include Unix file descriptors +// and C++ iostreams. See also: zero_copy_stream_impl_lite.h + +#ifndef GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_H__ +#define GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_H__ + + +#include +#include + +#include +#include +#include + + +#include + +namespace google { +namespace protobuf { +namespace io { + +// =================================================================== + +// A ZeroCopyInputStream which reads from a file descriptor. +// +// FileInputStream is preferred over using an ifstream with IstreamInputStream. +// The latter will introduce an extra layer of buffering, harming performance. +// Also, it's conceivable that FileInputStream could someday be enhanced +// to use zero-copy file descriptors on OSs which support them. +class PROTOBUF_EXPORT FileInputStream : public ZeroCopyInputStream { + public: + // Creates a stream that reads from the given Unix file descriptor. + // If a block_size is given, it specifies the number of bytes that + // should be read and returned with each call to Next(). Otherwise, + // a reasonable default is used. + explicit FileInputStream(int file_descriptor, int block_size = -1); + + // Flushes any buffers and closes the underlying file. Returns false if + // an error occurs during the process; use GetErrno() to examine the error. + // Even if an error occurs, the file descriptor is closed when this returns. + bool Close(); + + // By default, the file descriptor is not closed when the stream is + // destroyed. Call SetCloseOnDelete(true) to change that. WARNING: + // This leaves no way for the caller to detect if close() fails. If + // detecting close() errors is important to you, you should arrange + // to close the descriptor yourself. + void SetCloseOnDelete(bool value) { copying_input_.SetCloseOnDelete(value); } + + // If an I/O error has occurred on this file descriptor, this is the + // errno from that error. Otherwise, this is zero. Once an error + // occurs, the stream is broken and all subsequent operations will + // fail. + int GetErrno() const { return copying_input_.GetErrno(); } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingFileInputStream : public CopyingInputStream { + public: + CopyingFileInputStream(int file_descriptor); + ~CopyingFileInputStream() override; + + bool Close(); + void SetCloseOnDelete(bool value) { close_on_delete_ = value; } + int GetErrno() const { return errno_; } + + // implements CopyingInputStream --------------------------------- + int Read(void* buffer, int size) override; + int Skip(int count) override; + + private: + // The file descriptor. + const int file_; + bool close_on_delete_; + bool is_closed_; + + // The errno of the I/O error, if one has occurred. Otherwise, zero. + int errno_; + + // Did we try to seek once and fail? If so, we assume this file descriptor + // doesn't support seeking and won't try again. + bool previous_seek_failed_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingFileInputStream); + }; + + CopyingFileInputStream copying_input_; + CopyingInputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FileInputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream which writes to a file descriptor. +// +// FileOutputStream is preferred over using an ofstream with +// OstreamOutputStream. The latter will introduce an extra layer of buffering, +// harming performance. Also, it's conceivable that FileOutputStream could +// someday be enhanced to use zero-copy file descriptors on OSs which +// support them. +class PROTOBUF_EXPORT FileOutputStream : public ZeroCopyOutputStream { + public: + // Creates a stream that writes to the given Unix file descriptor. + // If a block_size is given, it specifies the size of the buffers + // that should be returned by Next(). Otherwise, a reasonable default + // is used. + explicit FileOutputStream(int file_descriptor, int block_size = -1); + ~FileOutputStream() override; + + // Flushes any buffers and closes the underlying file. Returns false if + // an error occurs during the process; use GetErrno() to examine the error. + // Even if an error occurs, the file descriptor is closed when this returns. + bool Close(); + + // Flushes FileOutputStream's buffers but does not close the + // underlying file. No special measures are taken to ensure that + // underlying operating system file object is synchronized to disk. + bool Flush(); + + // By default, the file descriptor is not closed when the stream is + // destroyed. Call SetCloseOnDelete(true) to change that. WARNING: + // This leaves no way for the caller to detect if close() fails. If + // detecting close() errors is important to you, you should arrange + // to close the descriptor yourself. + void SetCloseOnDelete(bool value) { copying_output_.SetCloseOnDelete(value); } + + // If an I/O error has occurred on this file descriptor, this is the + // errno from that error. Otherwise, this is zero. Once an error + // occurs, the stream is broken and all subsequent operations will + // fail. + int GetErrno() const { return copying_output_.GetErrno(); } + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingFileOutputStream : public CopyingOutputStream { + public: + CopyingFileOutputStream(int file_descriptor); + ~CopyingFileOutputStream() override; + + bool Close(); + void SetCloseOnDelete(bool value) { close_on_delete_ = value; } + int GetErrno() const { return errno_; } + + // implements CopyingOutputStream -------------------------------- + bool Write(const void* buffer, int size) override; + + private: + // The file descriptor. + const int file_; + bool close_on_delete_; + bool is_closed_; + + // The errno of the I/O error, if one has occurred. Otherwise, zero. + int errno_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingFileOutputStream); + }; + + CopyingFileOutputStream copying_output_; + CopyingOutputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FileOutputStream); +}; + +// =================================================================== + +// A ZeroCopyInputStream which reads from a C++ istream. +// +// Note that for reading files (or anything represented by a file descriptor), +// FileInputStream is more efficient. +class PROTOBUF_EXPORT IstreamInputStream : public ZeroCopyInputStream { + public: + // Creates a stream that reads from the given C++ istream. + // If a block_size is given, it specifies the number of bytes that + // should be read and returned with each call to Next(). Otherwise, + // a reasonable default is used. + explicit IstreamInputStream(std::istream* stream, int block_size = -1); + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingIstreamInputStream : public CopyingInputStream { + public: + CopyingIstreamInputStream(std::istream* input); + ~CopyingIstreamInputStream() override; + + // implements CopyingInputStream --------------------------------- + int Read(void* buffer, int size) override; + // (We use the default implementation of Skip().) + + private: + // The stream. + std::istream* input_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingIstreamInputStream); + }; + + CopyingIstreamInputStream copying_input_; + CopyingInputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(IstreamInputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream which writes to a C++ ostream. +// +// Note that for writing files (or anything represented by a file descriptor), +// FileOutputStream is more efficient. +class PROTOBUF_EXPORT OstreamOutputStream : public ZeroCopyOutputStream { + public: + // Creates a stream that writes to the given C++ ostream. + // If a block_size is given, it specifies the size of the buffers + // that should be returned by Next(). Otherwise, a reasonable default + // is used. + explicit OstreamOutputStream(std::ostream* stream, int block_size = -1); + ~OstreamOutputStream() override; + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingOstreamOutputStream + : public CopyingOutputStream { + public: + CopyingOstreamOutputStream(std::ostream* output); + ~CopyingOstreamOutputStream() override; + + // implements CopyingOutputStream -------------------------------- + bool Write(const void* buffer, int size) override; + + private: + // The stream. + std::ostream* output_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingOstreamOutputStream); + }; + + CopyingOstreamOutputStream copying_output_; + CopyingOutputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(OstreamOutputStream); +}; + +// =================================================================== + +// A ZeroCopyInputStream which reads from several other streams in sequence. +// ConcatenatingInputStream is unable to distinguish between end-of-stream +// and read errors in the underlying streams, so it assumes any errors mean +// end-of-stream. So, if the underlying streams fail for any other reason, +// ConcatenatingInputStream may do odd things. It is suggested that you do +// not use ConcatenatingInputStream on streams that might produce read errors +// other than end-of-stream. +class PROTOBUF_EXPORT ConcatenatingInputStream : public ZeroCopyInputStream { + public: + // All streams passed in as well as the array itself must remain valid + // until the ConcatenatingInputStream is destroyed. + ConcatenatingInputStream(ZeroCopyInputStream* const streams[], int count); + ~ConcatenatingInputStream() override = default; + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + + private: + // As streams are retired, streams_ is incremented and count_ is + // decremented. + ZeroCopyInputStream* const* streams_; + int stream_count_; + int64 bytes_retired_; // Bytes read from previous streams. + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ConcatenatingInputStream); +}; + +// =================================================================== + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl_lite.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl_lite.h new file mode 100644 index 0000000000000000000000000000000000000000..83d2ac0dcf4e85f8055afae40ae5bd221fbe2383 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl_lite.h @@ -0,0 +1,411 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains common implementations of the interfaces defined in +// zero_copy_stream.h which are included in the "lite" protobuf library. +// These implementations cover I/O on raw arrays and strings, as well as +// adaptors which make it easy to implement streams based on traditional +// streams. Of course, many users will probably want to write their own +// implementations of these interfaces specific to the particular I/O +// abstractions they prefer to use, but these should cover the most common +// cases. + +#ifndef GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_LITE_H__ +#define GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_LITE_H__ + + +#include +#include +#include + +#include +#include +#include +#include + + +#include + +namespace google { +namespace protobuf { +namespace io { + +// =================================================================== + +// A ZeroCopyInputStream backed by an in-memory array of bytes. +class PROTOBUF_EXPORT ArrayInputStream : public ZeroCopyInputStream { + public: + // Create an InputStream that returns the bytes pointed to by "data". + // "data" remains the property of the caller but must remain valid until + // the stream is destroyed. If a block_size is given, calls to Next() + // will return data blocks no larger than the given size. Otherwise, the + // first call to Next() returns the entire array. block_size is mainly + // useful for testing; in production you would probably never want to set + // it. + ArrayInputStream(const void* data, int size, int block_size = -1); + ~ArrayInputStream() override = default; + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + + private: + const uint8* const data_; // The byte array. + const int size_; // Total size of the array. + const int block_size_; // How many bytes to return at a time. + + int position_; + int last_returned_size_; // How many bytes we returned last time Next() + // was called (used for error checking only). + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ArrayInputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream backed by an in-memory array of bytes. +class PROTOBUF_EXPORT ArrayOutputStream : public ZeroCopyOutputStream { + public: + // Create an OutputStream that writes to the bytes pointed to by "data". + // "data" remains the property of the caller but must remain valid until + // the stream is destroyed. If a block_size is given, calls to Next() + // will return data blocks no larger than the given size. Otherwise, the + // first call to Next() returns the entire array. block_size is mainly + // useful for testing; in production you would probably never want to set + // it. + ArrayOutputStream(void* data, int size, int block_size = -1); + ~ArrayOutputStream() override = default; + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + uint8* const data_; // The byte array. + const int size_; // Total size of the array. + const int block_size_; // How many bytes to return at a time. + + int position_; + int last_returned_size_; // How many bytes we returned last time Next() + // was called (used for error checking only). + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ArrayOutputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream which appends bytes to a string. +class PROTOBUF_EXPORT StringOutputStream : public ZeroCopyOutputStream { + public: + // Create a StringOutputStream which appends bytes to the given string. + // The string remains property of the caller, but it is mutated in arbitrary + // ways and MUST NOT be accessed in any way until you're done with the + // stream. Either be sure there's no further usage, or (safest) destroy the + // stream before using the contents. + // + // Hint: If you call target->reserve(n) before creating the stream, + // the first call to Next() will return at least n bytes of buffer + // space. + explicit StringOutputStream(std::string* target); + ~StringOutputStream() override = default; + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + static const int kMinimumSize = 16; + + std::string* target_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(StringOutputStream); +}; + +// Note: There is no StringInputStream. Instead, just create an +// ArrayInputStream as follows: +// ArrayInputStream input(str.data(), str.size()); + +// =================================================================== + +// A generic traditional input stream interface. +// +// Lots of traditional input streams (e.g. file descriptors, C stdio +// streams, and C++ iostreams) expose an interface where every read +// involves copying bytes into a buffer. If you want to take such an +// interface and make a ZeroCopyInputStream based on it, simply implement +// CopyingInputStream and then use CopyingInputStreamAdaptor. +// +// CopyingInputStream implementations should avoid buffering if possible. +// CopyingInputStreamAdaptor does its own buffering and will read data +// in large blocks. +class PROTOBUF_EXPORT CopyingInputStream { + public: + virtual ~CopyingInputStream() {} + + // Reads up to "size" bytes into the given buffer. Returns the number of + // bytes read. Read() waits until at least one byte is available, or + // returns zero if no bytes will ever become available (EOF), or -1 if a + // permanent read error occurred. + virtual int Read(void* buffer, int size) = 0; + + // Skips the next "count" bytes of input. Returns the number of bytes + // actually skipped. This will always be exactly equal to "count" unless + // EOF was reached or a permanent read error occurred. + // + // The default implementation just repeatedly calls Read() into a scratch + // buffer. + virtual int Skip(int count); +}; + +// A ZeroCopyInputStream which reads from a CopyingInputStream. This is +// useful for implementing ZeroCopyInputStreams that read from traditional +// streams. Note that this class is not really zero-copy. +// +// If you want to read from file descriptors or C++ istreams, this is +// already implemented for you: use FileInputStream or IstreamInputStream +// respectively. +class PROTOBUF_EXPORT CopyingInputStreamAdaptor : public ZeroCopyInputStream { + public: + // Creates a stream that reads from the given CopyingInputStream. + // If a block_size is given, it specifies the number of bytes that + // should be read and returned with each call to Next(). Otherwise, + // a reasonable default is used. The caller retains ownership of + // copying_stream unless SetOwnsCopyingStream(true) is called. + explicit CopyingInputStreamAdaptor(CopyingInputStream* copying_stream, + int block_size = -1); + ~CopyingInputStreamAdaptor() override; + + // Call SetOwnsCopyingStream(true) to tell the CopyingInputStreamAdaptor to + // delete the underlying CopyingInputStream when it is destroyed. + void SetOwnsCopyingStream(bool value) { owns_copying_stream_ = value; } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + // Insures that buffer_ is not NULL. + void AllocateBufferIfNeeded(); + // Frees the buffer and resets buffer_used_. + void FreeBuffer(); + + // The underlying copying stream. + CopyingInputStream* copying_stream_; + bool owns_copying_stream_; + + // True if we have seen a permanent error from the underlying stream. + bool failed_; + + // The current position of copying_stream_, relative to the point where + // we started reading. + int64 position_; + + // Data is read into this buffer. It may be NULL if no buffer is currently + // in use. Otherwise, it points to an array of size buffer_size_. + std::unique_ptr buffer_; + const int buffer_size_; + + // Number of valid bytes currently in the buffer (i.e. the size last + // returned by Next()). 0 <= buffer_used_ <= buffer_size_. + int buffer_used_; + + // Number of bytes in the buffer which were backed up over by a call to + // BackUp(). These need to be returned again. + // 0 <= backup_bytes_ <= buffer_used_ + int backup_bytes_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingInputStreamAdaptor); +}; + +// =================================================================== + +// A generic traditional output stream interface. +// +// Lots of traditional output streams (e.g. file descriptors, C stdio +// streams, and C++ iostreams) expose an interface where every write +// involves copying bytes from a buffer. If you want to take such an +// interface and make a ZeroCopyOutputStream based on it, simply implement +// CopyingOutputStream and then use CopyingOutputStreamAdaptor. +// +// CopyingOutputStream implementations should avoid buffering if possible. +// CopyingOutputStreamAdaptor does its own buffering and will write data +// in large blocks. +class PROTOBUF_EXPORT CopyingOutputStream { + public: + virtual ~CopyingOutputStream() {} + + // Writes "size" bytes from the given buffer to the output. Returns true + // if successful, false on a write error. + virtual bool Write(const void* buffer, int size) = 0; +}; + +// A ZeroCopyOutputStream which writes to a CopyingOutputStream. This is +// useful for implementing ZeroCopyOutputStreams that write to traditional +// streams. Note that this class is not really zero-copy. +// +// If you want to write to file descriptors or C++ ostreams, this is +// already implemented for you: use FileOutputStream or OstreamOutputStream +// respectively. +class PROTOBUF_EXPORT CopyingOutputStreamAdaptor : public ZeroCopyOutputStream { + public: + // Creates a stream that writes to the given Unix file descriptor. + // If a block_size is given, it specifies the size of the buffers + // that should be returned by Next(). Otherwise, a reasonable default + // is used. + explicit CopyingOutputStreamAdaptor(CopyingOutputStream* copying_stream, + int block_size = -1); + ~CopyingOutputStreamAdaptor() override; + + // Writes all pending data to the underlying stream. Returns false if a + // write error occurred on the underlying stream. (The underlying + // stream itself is not necessarily flushed.) + bool Flush(); + + // Call SetOwnsCopyingStream(true) to tell the CopyingOutputStreamAdaptor to + // delete the underlying CopyingOutputStream when it is destroyed. + void SetOwnsCopyingStream(bool value) { owns_copying_stream_ = value; } + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + // Write the current buffer, if it is present. + bool WriteBuffer(); + // Insures that buffer_ is not NULL. + void AllocateBufferIfNeeded(); + // Frees the buffer. + void FreeBuffer(); + + // The underlying copying stream. + CopyingOutputStream* copying_stream_; + bool owns_copying_stream_; + + // True if we have seen a permanent error from the underlying stream. + bool failed_; + + // The current position of copying_stream_, relative to the point where + // we started writing. + int64 position_; + + // Data is written from this buffer. It may be NULL if no buffer is + // currently in use. Otherwise, it points to an array of size buffer_size_. + std::unique_ptr buffer_; + const int buffer_size_; + + // Number of valid bytes currently in the buffer (i.e. the size last + // returned by Next()). When BackUp() is called, we just reduce this. + // 0 <= buffer_used_ <= buffer_size_. + int buffer_used_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingOutputStreamAdaptor); +}; + +// =================================================================== + +// A ZeroCopyInputStream which wraps some other stream and limits it to +// a particular byte count. +class PROTOBUF_EXPORT LimitingInputStream : public ZeroCopyInputStream { + public: + LimitingInputStream(ZeroCopyInputStream* input, int64 limit); + ~LimitingInputStream() override; + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + + private: + ZeroCopyInputStream* input_; + int64 limit_; // Decreases as we go, becomes negative if we overshoot. + int64 prior_bytes_read_; // Bytes read on underlying stream at construction + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(LimitingInputStream); +}; + + +// =================================================================== + +// mutable_string_data() and as_string_data() are workarounds to improve +// the performance of writing new data to an existing string. Unfortunately +// the methods provided by the string class are suboptimal, and using memcpy() +// is mildly annoying because it requires its pointer args to be non-NULL even +// if we ask it to copy 0 bytes. Furthermore, string_as_array() has the +// property that it always returns NULL if its arg is the empty string, exactly +// what we want to avoid if we're using it in conjunction with memcpy()! +// With C++11, the desired memcpy() boils down to memcpy(..., &(*s)[0], size), +// where s is a string*. Without C++11, &(*s)[0] is not guaranteed to be safe, +// so we use string_as_array(), and live with the extra logic that tests whether +// *s is empty. + +// Return a pointer to mutable characters underlying the given string. The +// return value is valid until the next time the string is resized. We +// trust the caller to treat the return value as an array of length s->size(). +inline char* mutable_string_data(std::string* s) { + // This should be simpler & faster than string_as_array() because the latter + // is guaranteed to return NULL when *s is empty, so it has to check for that. + return &(*s)[0]; +} + +// as_string_data(s) is equivalent to +// ({ char* p = mutable_string_data(s); make_pair(p, p != NULL); }) +// Sometimes it's faster: in some scenarios p cannot be NULL, and then the +// code can avoid that check. +inline std::pair as_string_data(std::string* s) { + char* p = mutable_string_data(s); + return std::make_pair(p, true); +} + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_LITE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map.h new file mode 100644 index 0000000000000000000000000000000000000000..540c914b1d675aab8ba3a843500c098f4959863f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map.h @@ -0,0 +1,1280 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file defines the map container and its helpers to support protobuf maps. +// +// The Map and MapIterator types are provided by this header file. +// Please avoid using other types defined here, unless they are public +// types within Map or MapIterator, such as Map::value_type. + +#ifndef GOOGLE_PROTOBUF_MAP_H__ +#define GOOGLE_PROTOBUF_MAP_H__ + +#include +#include +#include +#include // To support Visual Studio 2008 +#include +#include +#include +#include + +#if defined(__cpp_lib_string_view) +#include +#endif // defined(__cpp_lib_string_view) + +#include +#include +#include +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +namespace google { +namespace protobuf { + +template +class Map; + +class MapIterator; + +template +struct is_proto_enum; + +namespace internal { +template +class MapFieldLite; + +template +class MapField; + +template +class TypeDefinedMapFieldBase; + +class DynamicMapField; + +class GeneratedMessageReflection; + +// re-implement std::allocator to use arena allocator for memory allocation. +// Used for Map implementation. Users should not use this class +// directly. +template +class MapAllocator { + public: + using value_type = U; + using pointer = value_type*; + using const_pointer = const value_type*; + using reference = value_type&; + using const_reference = const value_type&; + using size_type = size_t; + using difference_type = ptrdiff_t; + + MapAllocator() : arena_(nullptr) {} + explicit MapAllocator(Arena* arena) : arena_(arena) {} + template + MapAllocator(const MapAllocator& allocator) // NOLINT(runtime/explicit) + : arena_(allocator.arena()) {} + + pointer allocate(size_type n, const void* /* hint */ = nullptr) { + // If arena is not given, malloc needs to be called which doesn't + // construct element object. + if (arena_ == nullptr) { + return static_cast(::operator new(n * sizeof(value_type))); + } else { + return reinterpret_cast( + Arena::CreateArray(arena_, n * sizeof(value_type))); + } + } + + void deallocate(pointer p, size_type n) { + if (arena_ == nullptr) { +#if defined(__GXX_DELETE_WITH_SIZE__) || defined(__cpp_sized_deallocation) + ::operator delete(p, n * sizeof(value_type)); +#else + (void)n; + ::operator delete(p); +#endif + } + } + +#if __cplusplus >= 201103L && !defined(GOOGLE_PROTOBUF_OS_APPLE) && \ + !defined(GOOGLE_PROTOBUF_OS_NACL) && \ + !defined(GOOGLE_PROTOBUF_OS_EMSCRIPTEN) + template + void construct(NodeType* p, Args&&... args) { + // Clang 3.6 doesn't compile static casting to void* directly. (Issue + // #1266) According C++ standard 5.2.9/1: "The static_cast operator shall + // not cast away constness". So first the maybe const pointer is casted to + // const void* and after the const void* is const casted. + new (const_cast(static_cast(p))) + NodeType(std::forward(args)...); + } + + template + void destroy(NodeType* p) { + p->~NodeType(); + } +#else + void construct(pointer p, const_reference t) { new (p) value_type(t); } + + void destroy(pointer p) { p->~value_type(); } +#endif + + template + struct rebind { + using other = MapAllocator; + }; + + template + bool operator==(const MapAllocator& other) const { + return arena_ == other.arena_; + } + + template + bool operator!=(const MapAllocator& other) const { + return arena_ != other.arena_; + } + + // To support Visual Studio 2008 + size_type max_size() const { + // parentheses around (std::...:max) prevents macro warning of max() + return (std::numeric_limits::max)(); + } + + // To support gcc-4.4, which does not properly + // support templated friend classes + Arena* arena() const { return arena_; } + + private: + using DestructorSkippable_ = void; + Arena* const arena_; +}; + +template +using KeyForTree = + typename std::conditional::value, T, + std::reference_wrapper>::type; + +// Default case: Not transparent. +// We use std::hash/std::less and all the lookup functions +// only accept `key_type`. +template +struct TransparentSupport { + using hash = std::hash; + using less = std::less; + + static bool Equals(const key_type& a, const key_type& b) { return a == b; } + + template + using key_arg = key_type; +}; + +#if defined(__cpp_lib_string_view) +// If std::string_view is available, we add transparent support for std::string +// keys. We use std::hash as it supports the input types we +// care about. The lookup functions accept arbitrary `K`. This will include any +// key type that is convertible to std::string_view. +template <> +struct TransparentSupport { + static std::string_view ImplicitConvert(std::string_view str) { return str; } + // If the element is not convertible to std::string_view, try to convert to + // std::string first. + // The template makes this overload lose resolution when both have the same + // rank otherwise. + template + static std::string_view ImplicitConvert(const std::string& str) { + return str; + } + + struct hash : private std::hash { + using is_transparent = void; + + template + size_t operator()(const T& str) const { + return base()(ImplicitConvert(str)); + } + + private: + const std::hash& base() const { return *this; } + }; + struct less { + using is_transparent = void; + + template + bool operator()(const T& t, const U& u) const { + return ImplicitConvert(t) < ImplicitConvert(u); + } + }; + + template + static bool Equals(const T& t, const U& u) { + return ImplicitConvert(t) == ImplicitConvert(u); + } + + template + using key_arg = K; +}; +#endif // defined(__cpp_lib_string_view) + +} // namespace internal + +// This is the class for Map's internal value_type. Instead of using +// std::pair as value_type, we use this class which provides us more control of +// its process of construction and destruction. +template +struct MapPair { + using first_type = const Key; + using second_type = T; + + MapPair(const Key& other_first, const T& other_second) + : first(other_first), second(other_second) {} + explicit MapPair(const Key& other_first) : first(other_first), second() {} + MapPair(const MapPair& other) : first(other.first), second(other.second) {} + + ~MapPair() {} + + // Implicitly convertible to std::pair of compatible types. + template + operator std::pair() const { // NOLINT(runtime/explicit) + return std::pair(first, second); + } + + const Key first; + T second; + + private: + friend class Arena; + friend class Map; +}; + +// Map is an associative container type used to store protobuf map +// fields. Each Map instance may or may not use a different hash function, a +// different iteration order, and so on. E.g., please don't examine +// implementation details to decide if the following would work: +// Map m0, m1; +// m0[0] = m1[0] = m0[1] = m1[1] = 0; +// assert(m0.begin()->first == m1.begin()->first); // Bug! +// +// Map's interface is similar to std::unordered_map, except that Map is not +// designed to play well with exceptions. +template +class Map { + public: + using key_type = Key; + using mapped_type = T; + using value_type = MapPair; + + using pointer = value_type*; + using const_pointer = const value_type*; + using reference = value_type&; + using const_reference = const value_type&; + + using size_type = size_t; + using hasher = typename internal::TransparentSupport::hash; + + Map() : arena_(nullptr), default_enum_value_(0) { Init(); } + explicit Map(Arena* arena) : arena_(arena), default_enum_value_(0) { Init(); } + + Map(const Map& other) + : arena_(nullptr), default_enum_value_(other.default_enum_value_) { + Init(); + insert(other.begin(), other.end()); + } + + Map(Map&& other) noexcept : Map() { + if (other.arena_) { + *this = other; + } else { + swap(other); + } + } + Map& operator=(Map&& other) noexcept { + if (this != &other) { + if (arena_ != other.arena_) { + *this = other; + } else { + swap(other); + } + } + return *this; + } + + template + Map(const InputIt& first, const InputIt& last) + : arena_(nullptr), default_enum_value_(0) { + Init(); + insert(first, last); + } + + ~Map() { + if (arena_ == nullptr) { + clear(); + delete elements_; + } + } + + private: + void Init() { elements_ = Arena::CreateMessage(arena_, 0); } + + using Allocator = internal::MapAllocator; + + // InnerMap is a generic hash-based map. It doesn't contain any + // protocol-buffer-specific logic. It is a chaining hash map with the + // additional feature that some buckets can be converted to use an ordered + // container. This ensures O(lg n) bounds on find, insert, and erase, while + // avoiding the overheads of ordered containers most of the time. + // + // The implementation doesn't need the full generality of unordered_map, + // and it doesn't have it. More bells and whistles can be added as needed. + // Some implementation details: + // 1. The hash function has type hasher and the equality function + // equal_to. We inherit from hasher to save space + // (empty-base-class optimization). + // 2. The number of buckets is a power of two. + // 3. Buckets are converted to trees in pairs: if we convert bucket b then + // buckets b and b^1 will share a tree. Invariant: buckets b and b^1 have + // the same non-null value iff they are sharing a tree. (An alternative + // implementation strategy would be to have a tag bit per bucket.) + // 4. As is typical for hash_map and such, the Keys and Values are always + // stored in linked list nodes. Pointers to elements are never invalidated + // until the element is deleted. + // 5. The trees' payload type is pointer to linked-list node. Tree-converting + // a bucket doesn't copy Key-Value pairs. + // 6. Once we've tree-converted a bucket, it is never converted back. However, + // the items a tree contains may wind up assigned to trees or lists upon a + // rehash. + // 7. The code requires no C++ features from C++14 or later. + // 8. Mutations to a map do not invalidate the map's iterators, pointers to + // elements, or references to elements. + // 9. Except for erase(iterator), any non-const method can reorder iterators. + // 10. InnerMap uses KeyForTree when using the Tree representation, which + // is either `Key`, if Key is a scalar, or `reference_wrapper` + // otherwise. This avoids unncessary copies of string keys, for example. + class InnerMap : private hasher { + public: + explicit InnerMap(size_type n) : InnerMap(nullptr, n) {} + InnerMap(Arena* arena, size_type n) + : hasher(), + num_elements_(0), + seed_(Seed()), + table_(nullptr), + alloc_(arena) { + n = TableSize(n); + table_ = CreateEmptyTable(n); + num_buckets_ = index_of_first_non_null_ = n; + } + + ~InnerMap() { + if (table_ != nullptr) { + clear(); + Dealloc(table_, num_buckets_); + } + } + + private: + enum { kMinTableSize = 8 }; + + // Linked-list nodes, as one would expect for a chaining hash table. + struct Node { + value_type kv; + Node* next; + }; + + // Trees. The payload type is a copy of Key, so that we can query the tree + // with Keys that are not in any particular data structure. + // The value is a void* pointing to Node. We use void* instead of Node* to + // avoid code bloat. That way there is only one instantiation of the tree + // class per key type. + using TreeAllocator = typename Allocator::template rebind< + std::pair, void*>>::other; + using Tree = std::map, void*, + typename internal::TransparentSupport::less, + TreeAllocator>; + using TreeIterator = typename Tree::iterator; + + static Node* NodeFromTreeIterator(TreeIterator it) { + return static_cast(it->second); + } + + // iterator and const_iterator are instantiations of iterator_base. + template + class iterator_base { + public: + using reference = KeyValueType&; + using pointer = KeyValueType*; + + // Invariants: + // node_ is always correct. This is handy because the most common + // operations are operator* and operator-> and they only use node_. + // When node_ is set to a non-null value, all the other non-const fields + // are updated to be correct also, but those fields can become stale + // if the underlying map is modified. When those fields are needed they + // are rechecked, and updated if necessary. + iterator_base() : node_(nullptr), m_(nullptr), bucket_index_(0) {} + + explicit iterator_base(const InnerMap* m) : m_(m) { + SearchFrom(m->index_of_first_non_null_); + } + + // Any iterator_base can convert to any other. This is overkill, and we + // rely on the enclosing class to use it wisely. The standard "iterator + // can convert to const_iterator" is OK but the reverse direction is not. + template + explicit iterator_base(const iterator_base& it) + : node_(it.node_), m_(it.m_), bucket_index_(it.bucket_index_) {} + + iterator_base(Node* n, const InnerMap* m, size_type index) + : node_(n), m_(m), bucket_index_(index) {} + + iterator_base(TreeIterator tree_it, const InnerMap* m, size_type index) + : node_(NodeFromTreeIterator(tree_it)), m_(m), bucket_index_(index) { + // Invariant: iterators that use buckets with trees have an even + // bucket_index_. + GOOGLE_DCHECK_EQ(bucket_index_ % 2, 0u); + } + + // Advance through buckets, looking for the first that isn't empty. + // If nothing non-empty is found then leave node_ == nullptr. + void SearchFrom(size_type start_bucket) { + GOOGLE_DCHECK(m_->index_of_first_non_null_ == m_->num_buckets_ || + m_->table_[m_->index_of_first_non_null_] != nullptr); + node_ = nullptr; + for (bucket_index_ = start_bucket; bucket_index_ < m_->num_buckets_; + bucket_index_++) { + if (m_->TableEntryIsNonEmptyList(bucket_index_)) { + node_ = static_cast(m_->table_[bucket_index_]); + break; + } else if (m_->TableEntryIsTree(bucket_index_)) { + Tree* tree = static_cast(m_->table_[bucket_index_]); + GOOGLE_DCHECK(!tree->empty()); + node_ = NodeFromTreeIterator(tree->begin()); + break; + } + } + } + + reference operator*() const { return node_->kv; } + pointer operator->() const { return &(operator*()); } + + friend bool operator==(const iterator_base& a, const iterator_base& b) { + return a.node_ == b.node_; + } + friend bool operator!=(const iterator_base& a, const iterator_base& b) { + return a.node_ != b.node_; + } + + iterator_base& operator++() { + if (node_->next == nullptr) { + TreeIterator tree_it; + const bool is_list = revalidate_if_necessary(&tree_it); + if (is_list) { + SearchFrom(bucket_index_ + 1); + } else { + GOOGLE_DCHECK_EQ(bucket_index_ & 1, 0u); + Tree* tree = static_cast(m_->table_[bucket_index_]); + if (++tree_it == tree->end()) { + SearchFrom(bucket_index_ + 2); + } else { + node_ = NodeFromTreeIterator(tree_it); + } + } + } else { + node_ = node_->next; + } + return *this; + } + + iterator_base operator++(int /* unused */) { + iterator_base tmp = *this; + ++*this; + return tmp; + } + + // Assumes node_ and m_ are correct and non-null, but other fields may be + // stale. Fix them as needed. Then return true iff node_ points to a + // Node in a list. If false is returned then *it is modified to be + // a valid iterator for node_. + bool revalidate_if_necessary(TreeIterator* it) { + GOOGLE_DCHECK(node_ != nullptr && m_ != nullptr); + // Force bucket_index_ to be in range. + bucket_index_ &= (m_->num_buckets_ - 1); + // Common case: the bucket we think is relevant points to node_. + if (m_->table_[bucket_index_] == static_cast(node_)) return true; + // Less common: the bucket is a linked list with node_ somewhere in it, + // but not at the head. + if (m_->TableEntryIsNonEmptyList(bucket_index_)) { + Node* l = static_cast(m_->table_[bucket_index_]); + while ((l = l->next) != nullptr) { + if (l == node_) { + return true; + } + } + } + // Well, bucket_index_ still might be correct, but probably + // not. Revalidate just to be sure. This case is rare enough that we + // don't worry about potential optimizations, such as having a custom + // find-like method that compares Node* instead of the key. + iterator_base i(m_->find(node_->kv.first, it)); + bucket_index_ = i.bucket_index_; + return m_->TableEntryIsList(bucket_index_); + } + + Node* node_; + const InnerMap* m_; + size_type bucket_index_; + }; + + public: + using iterator = iterator_base; + using const_iterator = iterator_base; + + iterator begin() { return iterator(this); } + iterator end() { return iterator(); } + const_iterator begin() const { return const_iterator(this); } + const_iterator end() const { return const_iterator(); } + + void clear() { + for (size_type b = 0; b < num_buckets_; b++) { + if (TableEntryIsNonEmptyList(b)) { + Node* node = static_cast(table_[b]); + table_[b] = nullptr; + do { + Node* next = node->next; + DestroyNode(node); + node = next; + } while (node != nullptr); + } else if (TableEntryIsTree(b)) { + Tree* tree = static_cast(table_[b]); + GOOGLE_DCHECK(table_[b] == table_[b + 1] && (b & 1) == 0); + table_[b] = table_[b + 1] = nullptr; + typename Tree::iterator tree_it = tree->begin(); + do { + Node* node = NodeFromTreeIterator(tree_it); + typename Tree::iterator next = tree_it; + ++next; + tree->erase(tree_it); + DestroyNode(node); + tree_it = next; + } while (tree_it != tree->end()); + DestroyTree(tree); + b++; + } + } + num_elements_ = 0; + index_of_first_non_null_ = num_buckets_; + } + + const hasher& hash_function() const { return *this; } + + static size_type max_size() { + return static_cast(1) << (sizeof(void**) >= 8 ? 60 : 28); + } + size_type size() const { return num_elements_; } + bool empty() const { return size() == 0; } + + template + iterator find(const K& k) { + return iterator(FindHelper(k).first); + } + + // Insert the key into the map, if not present. In that case, the value will + // be value initialized. + std::pair insert(const Key& k) { + std::pair p = FindHelper(k); + // Case 1: key was already present. + if (p.first.node_ != nullptr) + return std::make_pair(iterator(p.first), false); + // Case 2: insert. + if (ResizeIfLoadIsOutOfRange(num_elements_ + 1)) { + p = FindHelper(k); + } + const size_type b = p.second; // bucket number + Node* node; + if (alloc_.arena() == nullptr) { + node = new Node{value_type(k), nullptr}; + } else { + node = Alloc(1); + Arena::CreateInArenaStorage(const_cast(&node->kv.first), + alloc_.arena(), k); + Arena::CreateInArenaStorage(&node->kv.second, alloc_.arena()); + } + + iterator result = InsertUnique(b, node); + ++num_elements_; + return std::make_pair(result, true); + } + + value_type& operator[](const Key& k) { return *insert(k).first; } + + void erase(iterator it) { + GOOGLE_DCHECK_EQ(it.m_, this); + typename Tree::iterator tree_it; + const bool is_list = it.revalidate_if_necessary(&tree_it); + size_type b = it.bucket_index_; + Node* const item = it.node_; + if (is_list) { + GOOGLE_DCHECK(TableEntryIsNonEmptyList(b)); + Node* head = static_cast(table_[b]); + head = EraseFromLinkedList(item, head); + table_[b] = static_cast(head); + } else { + GOOGLE_DCHECK(TableEntryIsTree(b)); + Tree* tree = static_cast(table_[b]); + tree->erase(tree_it); + if (tree->empty()) { + // Force b to be the minimum of b and b ^ 1. This is important + // only because we want index_of_first_non_null_ to be correct. + b &= ~static_cast(1); + DestroyTree(tree); + table_[b] = table_[b + 1] = nullptr; + } + } + DestroyNode(item); + --num_elements_; + if (PROTOBUF_PREDICT_FALSE(b == index_of_first_non_null_)) { + while (index_of_first_non_null_ < num_buckets_ && + table_[index_of_first_non_null_] == nullptr) { + ++index_of_first_non_null_; + } + } + } + + private: + const_iterator find(const Key& k, TreeIterator* it) const { + return FindHelper(k, it).first; + } + template + std::pair FindHelper(const K& k) const { + return FindHelper(k, nullptr); + } + template + std::pair FindHelper(const K& k, + TreeIterator* it) const { + size_type b = BucketNumber(k); + if (TableEntryIsNonEmptyList(b)) { + Node* node = static_cast(table_[b]); + do { + if (internal::TransparentSupport::Equals(node->kv.first, k)) { + return std::make_pair(const_iterator(node, this, b), b); + } else { + node = node->next; + } + } while (node != nullptr); + } else if (TableEntryIsTree(b)) { + GOOGLE_DCHECK_EQ(table_[b], table_[b ^ 1]); + b &= ~static_cast(1); + Tree* tree = static_cast(table_[b]); + auto tree_it = tree->find(k); + if (tree_it != tree->end()) { + if (it != nullptr) *it = tree_it; + return std::make_pair(const_iterator(tree_it, this, b), b); + } + } + return std::make_pair(end(), b); + } + + // Insert the given Node in bucket b. If that would make bucket b too big, + // and bucket b is not a tree, create a tree for buckets b and b^1 to share. + // Requires count(*KeyPtrFromNodePtr(node)) == 0 and that b is the correct + // bucket. num_elements_ is not modified. + iterator InsertUnique(size_type b, Node* node) { + GOOGLE_DCHECK(index_of_first_non_null_ == num_buckets_ || + table_[index_of_first_non_null_] != nullptr); + // In practice, the code that led to this point may have already + // determined whether we are inserting into an empty list, a short list, + // or whatever. But it's probably cheap enough to recompute that here; + // it's likely that we're inserting into an empty or short list. + iterator result; + GOOGLE_DCHECK(find(node->kv.first) == end()); + if (TableEntryIsEmpty(b)) { + result = InsertUniqueInList(b, node); + } else if (TableEntryIsNonEmptyList(b)) { + if (PROTOBUF_PREDICT_FALSE(TableEntryIsTooLong(b))) { + TreeConvert(b); + result = InsertUniqueInTree(b, node); + GOOGLE_DCHECK_EQ(result.bucket_index_, b & ~static_cast(1)); + } else { + // Insert into a pre-existing list. This case cannot modify + // index_of_first_non_null_, so we skip the code to update it. + return InsertUniqueInList(b, node); + } + } else { + // Insert into a pre-existing tree. This case cannot modify + // index_of_first_non_null_, so we skip the code to update it. + return InsertUniqueInTree(b, node); + } + // parentheses around (std::min) prevents macro expansion of min(...) + index_of_first_non_null_ = + (std::min)(index_of_first_non_null_, result.bucket_index_); + return result; + } + + // Returns whether we should insert after the head of the list. For + // non-optimized builds, we randomly decide whether to insert right at the + // head of the list or just after the head. This helps add a little bit of + // non-determinism to the map ordering. + bool ShouldInsertAfterHead(void* node) { +#ifdef NDEBUG + return false; +#else + // Doing modulo with a prime mixes the bits more. + return (reinterpret_cast(node) ^ seed_) % 13 > 6; +#endif + } + + // Helper for InsertUnique. Handles the case where bucket b is a + // not-too-long linked list. + iterator InsertUniqueInList(size_type b, Node* node) { + if (table_[b] != nullptr && ShouldInsertAfterHead(node)) { + Node* first = static_cast(table_[b]); + node->next = first->next; + first->next = node; + return iterator(node, this, b); + } + + node->next = static_cast(table_[b]); + table_[b] = static_cast(node); + return iterator(node, this, b); + } + + // Helper for InsertUnique. Handles the case where bucket b points to a + // Tree. + iterator InsertUniqueInTree(size_type b, Node* node) { + GOOGLE_DCHECK_EQ(table_[b], table_[b ^ 1]); + // Maintain the invariant that node->next is null for all Nodes in Trees. + node->next = nullptr; + return iterator( + static_cast(table_[b])->insert({node->kv.first, node}).first, + this, b & ~static_cast(1)); + } + + // Returns whether it did resize. Currently this is only used when + // num_elements_ increases, though it could be used in other situations. + // It checks for load too low as well as load too high: because any number + // of erases can occur between inserts, the load could be as low as 0 here. + // Resizing to a lower size is not always helpful, but failing to do so can + // destroy the expected big-O bounds for some operations. By having the + // policy that sometimes we resize down as well as up, clients can easily + // keep O(size()) = O(number of buckets) if they want that. + bool ResizeIfLoadIsOutOfRange(size_type new_size) { + const size_type kMaxMapLoadTimes16 = 12; // controls RAM vs CPU tradeoff + const size_type hi_cutoff = num_buckets_ * kMaxMapLoadTimes16 / 16; + const size_type lo_cutoff = hi_cutoff / 4; + // We don't care how many elements are in trees. If a lot are, + // we may resize even though there are many empty buckets. In + // practice, this seems fine. + if (PROTOBUF_PREDICT_FALSE(new_size >= hi_cutoff)) { + if (num_buckets_ <= max_size() / 2) { + Resize(num_buckets_ * 2); + return true; + } + } else if (PROTOBUF_PREDICT_FALSE(new_size <= lo_cutoff && + num_buckets_ > kMinTableSize)) { + size_type lg2_of_size_reduction_factor = 1; + // It's possible we want to shrink a lot here... size() could even be 0. + // So, estimate how much to shrink by making sure we don't shrink so + // much that we would need to grow the table after a few inserts. + const size_type hypothetical_size = new_size * 5 / 4 + 1; + while ((hypothetical_size << lg2_of_size_reduction_factor) < + hi_cutoff) { + ++lg2_of_size_reduction_factor; + } + size_type new_num_buckets = std::max( + kMinTableSize, num_buckets_ >> lg2_of_size_reduction_factor); + if (new_num_buckets != num_buckets_) { + Resize(new_num_buckets); + return true; + } + } + return false; + } + + // Resize to the given number of buckets. + void Resize(size_t new_num_buckets) { + GOOGLE_DCHECK_GE(new_num_buckets, kMinTableSize); + void** const old_table = table_; + const size_type old_table_size = num_buckets_; + num_buckets_ = new_num_buckets; + table_ = CreateEmptyTable(num_buckets_); + const size_type start = index_of_first_non_null_; + index_of_first_non_null_ = num_buckets_; + for (size_type i = start; i < old_table_size; i++) { + if (TableEntryIsNonEmptyList(old_table, i)) { + TransferList(old_table, i); + } else if (TableEntryIsTree(old_table, i)) { + TransferTree(old_table, i++); + } + } + Dealloc(old_table, old_table_size); + } + + void TransferList(void* const* table, size_type index) { + Node* node = static_cast(table[index]); + do { + Node* next = node->next; + InsertUnique(BucketNumber(node->kv.first), node); + node = next; + } while (node != nullptr); + } + + void TransferTree(void* const* table, size_type index) { + Tree* tree = static_cast(table[index]); + typename Tree::iterator tree_it = tree->begin(); + do { + InsertUnique(BucketNumber(std::cref(tree_it->first).get()), + NodeFromTreeIterator(tree_it)); + } while (++tree_it != tree->end()); + DestroyTree(tree); + } + + Node* EraseFromLinkedList(Node* item, Node* head) { + if (head == item) { + return head->next; + } else { + head->next = EraseFromLinkedList(item, head->next); + return head; + } + } + + bool TableEntryIsEmpty(size_type b) const { + return TableEntryIsEmpty(table_, b); + } + bool TableEntryIsNonEmptyList(size_type b) const { + return TableEntryIsNonEmptyList(table_, b); + } + bool TableEntryIsTree(size_type b) const { + return TableEntryIsTree(table_, b); + } + bool TableEntryIsList(size_type b) const { + return TableEntryIsList(table_, b); + } + static bool TableEntryIsEmpty(void* const* table, size_type b) { + return table[b] == nullptr; + } + static bool TableEntryIsNonEmptyList(void* const* table, size_type b) { + return table[b] != nullptr && table[b] != table[b ^ 1]; + } + static bool TableEntryIsTree(void* const* table, size_type b) { + return !TableEntryIsEmpty(table, b) && + !TableEntryIsNonEmptyList(table, b); + } + static bool TableEntryIsList(void* const* table, size_type b) { + return !TableEntryIsTree(table, b); + } + + void TreeConvert(size_type b) { + GOOGLE_DCHECK(!TableEntryIsTree(b) && !TableEntryIsTree(b ^ 1)); + Tree* tree = + Arena::Create(alloc_.arena(), typename Tree::key_compare(), + typename Tree::allocator_type(alloc_)); + size_type count = CopyListToTree(b, tree) + CopyListToTree(b ^ 1, tree); + GOOGLE_DCHECK_EQ(count, tree->size()); + table_[b] = table_[b ^ 1] = static_cast(tree); + } + + // Copy a linked list in the given bucket to a tree. + // Returns the number of things it copied. + size_type CopyListToTree(size_type b, Tree* tree) { + size_type count = 0; + Node* node = static_cast(table_[b]); + while (node != nullptr) { + tree->insert({node->kv.first, node}); + ++count; + Node* next = node->next; + node->next = nullptr; + node = next; + } + return count; + } + + // Return whether table_[b] is a linked list that seems awfully long. + // Requires table_[b] to point to a non-empty linked list. + bool TableEntryIsTooLong(size_type b) { + const size_type kMaxLength = 8; + size_type count = 0; + Node* node = static_cast(table_[b]); + do { + ++count; + node = node->next; + } while (node != nullptr); + // Invariant: no linked list ever is more than kMaxLength in length. + GOOGLE_DCHECK_LE(count, kMaxLength); + return count >= kMaxLength; + } + + template + size_type BucketNumber(const K& k) const { + // We xor the hash value against the random seed so that we effectively + // have a random hash function. + uint64 h = hash_function()(k) ^ seed_; + + // We use the multiplication method to determine the bucket number from + // the hash value. The constant kPhi (suggested by Knuth) is roughly + // (sqrt(5) - 1) / 2 * 2^64. + constexpr uint64 kPhi = uint64{0x9e3779b97f4a7c15}; + return ((kPhi * h) >> 32) & (num_buckets_ - 1); + } + + // Return a power of two no less than max(kMinTableSize, n). + // Assumes either n < kMinTableSize or n is a power of two. + size_type TableSize(size_type n) { + return n < static_cast(kMinTableSize) + ? static_cast(kMinTableSize) + : n; + } + + // Use alloc_ to allocate an array of n objects of type U. + template + U* Alloc(size_type n) { + using alloc_type = typename Allocator::template rebind::other; + return alloc_type(alloc_).allocate(n); + } + + // Use alloc_ to deallocate an array of n objects of type U. + template + void Dealloc(U* t, size_type n) { + using alloc_type = typename Allocator::template rebind::other; + alloc_type(alloc_).deallocate(t, n); + } + + void DestroyNode(Node* node) { + if (alloc_.arena() == nullptr) { + delete node; + } + } + + void DestroyTree(Tree* tree) { + if (alloc_.arena() == nullptr) { + delete tree; + } + } + + void** CreateEmptyTable(size_type n) { + GOOGLE_DCHECK(n >= kMinTableSize); + GOOGLE_DCHECK_EQ(n & (n - 1), 0); + void** result = Alloc(n); + memset(result, 0, n * sizeof(result[0])); + return result; + } + + // Return a randomish value. + size_type Seed() const { + // We get a little bit of randomness from the address of the map. The + // lower bits are not very random, due to alignment, so we discard them + // and shift the higher bits into their place. + size_type s = reinterpret_cast(this) >> 12; +#if defined(__x86_64__) && defined(__GNUC__) && \ + !defined(GOOGLE_PROTOBUF_NO_RDTSC) + uint32 hi, lo; + asm("rdtsc" : "=a"(lo), "=d"(hi)); + s += ((static_cast(hi) << 32) | lo); +#endif + return s; + } + + friend class Arena; + using InternalArenaConstructable_ = void; + using DestructorSkippable_ = void; + + size_type num_elements_; + size_type num_buckets_; + size_type seed_; + size_type index_of_first_non_null_; + void** table_; // an array with num_buckets_ entries + Allocator alloc_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(InnerMap); + }; // end of class InnerMap + + template + using key_arg = typename internal::TransparentSupport< + key_type>::template key_arg; + + public: + // Iterators + class const_iterator { + using InnerIt = typename InnerMap::const_iterator; + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = typename Map::value_type; + using difference_type = ptrdiff_t; + using pointer = const value_type*; + using reference = const value_type&; + + const_iterator() {} + explicit const_iterator(const InnerIt& it) : it_(it) {} + + const_reference operator*() const { return *it_; } + const_pointer operator->() const { return &(operator*()); } + + const_iterator& operator++() { + ++it_; + return *this; + } + const_iterator operator++(int) { return const_iterator(it_++); } + + friend bool operator==(const const_iterator& a, const const_iterator& b) { + return a.it_ == b.it_; + } + friend bool operator!=(const const_iterator& a, const const_iterator& b) { + return !(a == b); + } + + private: + InnerIt it_; + }; + + class iterator { + using InnerIt = typename InnerMap::iterator; + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = typename Map::value_type; + using difference_type = ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + + iterator() {} + explicit iterator(const InnerIt& it) : it_(it) {} + + reference operator*() const { return *it_; } + pointer operator->() const { return &(operator*()); } + + iterator& operator++() { + ++it_; + return *this; + } + iterator operator++(int) { return iterator(it_++); } + + // Allow implicit conversion to const_iterator. + operator const_iterator() const { // NOLINT(runtime/explicit) + return const_iterator(typename InnerMap::const_iterator(it_)); + } + + friend bool operator==(const iterator& a, const iterator& b) { + return a.it_ == b.it_; + } + friend bool operator!=(const iterator& a, const iterator& b) { + return !(a == b); + } + + private: + friend class Map; + + InnerIt it_; + }; + + iterator begin() { return iterator(elements_->begin()); } + iterator end() { return iterator(elements_->end()); } + const_iterator begin() const { + return const_iterator(iterator(elements_->begin())); + } + const_iterator end() const { + return const_iterator(iterator(elements_->end())); + } + const_iterator cbegin() const { return begin(); } + const_iterator cend() const { return end(); } + + // Capacity + size_type size() const { return elements_->size(); } + bool empty() const { return size() == 0; } + + // Element access + T& operator[](const key_type& key) { return (*elements_)[key].second; } + + template + const T& at(const key_arg& key) const { + const_iterator it = find(key); + GOOGLE_CHECK(it != end()) << "key not found: " << static_cast(key); + return it->second; + } + + template + T& at(const key_arg& key) { + iterator it = find(key); + GOOGLE_CHECK(it != end()) << "key not found: " << static_cast(key); + return it->second; + } + + // Lookup + template + size_type count(const key_arg& key) const { + return find(key) == end() ? 0 : 1; + } + + template + const_iterator find(const key_arg& key) const { + return const_iterator(iterator(elements_->find(key))); + } + template + iterator find(const key_arg& key) { + return iterator(elements_->find(key)); + } + + template + bool contains(const key_arg& key) const { + return find(key) != end(); + } + + template + std::pair equal_range( + const key_arg& key) const { + const_iterator it = find(key); + if (it == end()) { + return std::pair(it, it); + } else { + const_iterator begin = it++; + return std::pair(begin, it); + } + } + + template + std::pair equal_range(const key_arg& key) { + iterator it = find(key); + if (it == end()) { + return std::pair(it, it); + } else { + iterator begin = it++; + return std::pair(begin, it); + } + } + + // insert + std::pair insert(const value_type& value) { + std::pair p = + elements_->insert(value.first); + if (p.second) { + p.first->second = value.second; + } + return std::pair(iterator(p.first), p.second); + } + template + void insert(InputIt first, InputIt last) { + for (InputIt it = first; it != last; ++it) { + iterator exist_it = find(it->first); + if (exist_it == end()) { + operator[](it->first) = it->second; + } + } + } + void insert(std::initializer_list values) { + insert(values.begin(), values.end()); + } + + // Erase and clear + template + size_type erase(const key_arg& key) { + iterator it = find(key); + if (it == end()) { + return 0; + } else { + erase(it); + return 1; + } + } + iterator erase(iterator pos) { + iterator i = pos++; + elements_->erase(i.it_); + return pos; + } + void erase(iterator first, iterator last) { + while (first != last) { + first = erase(first); + } + } + void clear() { elements_->clear(); } + + // Assign + Map& operator=(const Map& other) { + if (this != &other) { + clear(); + insert(other.begin(), other.end()); + } + return *this; + } + + void swap(Map& other) { + if (arena_ == other.arena_) { + std::swap(default_enum_value_, other.default_enum_value_); + std::swap(elements_, other.elements_); + } else { + // TODO(zuguang): optimize this. The temporary copy can be allocated + // in the same arena as the other message, and the "other = copy" can + // be replaced with the fast-path swap above. + Map copy = *this; + *this = other; + other = copy; + } + } + + // Access to hasher. Currently this returns a copy, but it may + // be modified to return a const reference in the future. + hasher hash_function() const { return elements_->hash_function(); } + + private: + // Set default enum value only for proto2 map field whose value is enum type. + void SetDefaultEnumValue(int default_enum_value) { + default_enum_value_ = default_enum_value; + } + + Arena* arena_; + int default_enum_value_; + InnerMap* elements_; + + friend class Arena; + using InternalArenaConstructable_ = void; + using DestructorSkippable_ = void; + template + friend class internal::MapFieldLite; +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_MAP_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_entry.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_entry.h new file mode 100644 index 0000000000000000000000000000000000000000..c636d0ae2b4df4cc26892e86a62305388f125e6f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_entry.h @@ -0,0 +1,172 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_MAP_ENTRY_H__ +#define GOOGLE_PROTOBUF_MAP_ENTRY_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +class Arena; +namespace internal { +template +class MapField; +} +} // namespace protobuf +} // namespace google + +namespace google { +namespace protobuf { +namespace internal { + +// MapEntry is the returned google::protobuf::Message when calling AddMessage of +// google::protobuf::Reflection. In order to let it work with generated message +// reflection, its in-memory type is the same as generated message with the same +// fields. However, in order to decide the in-memory type of key/value, we need +// to know both their cpp type in generated api and proto type. In +// implementation, all in-memory types have related wire format functions to +// support except ArenaStringPtr. Therefore, we need to define another type with +// supporting wire format functions. Since this type is only used as return type +// of MapEntry accessors, it's named MapEntry accessor type. +// +// cpp type: the type visible to users in public API. +// proto type: WireFormatLite::FieldType of the field. +// in-memory type: type of the data member used to stored this field. +// MapEntry accessor type: type used in MapEntry getters/mutators to access the +// field. +// +// cpp type | proto type | in-memory type | MapEntry accessor type +// int32 TYPE_INT32 int32 int32 +// int32 TYPE_FIXED32 int32 int32 +// string TYPE_STRING ArenaStringPtr string +// FooEnum TYPE_ENUM int int +// FooMessage TYPE_MESSAGE FooMessage* FooMessage +// +// The in-memory types of primitive types can be inferred from its proto type, +// while we need to explicitly specify the cpp type if proto type is +// TYPE_MESSAGE to infer the in-memory type. Moreover, default_enum_value is +// used to initialize enum field in proto2. +template +class MapEntry + : public MapEntryImpl { + public: + MapEntry() : _internal_metadata_(NULL) {} + explicit MapEntry(Arena* arena) + : MapEntryImpl(arena), + _internal_metadata_(arena) {} + ~MapEntry() { + Message::_internal_metadata_.Delete(); + _internal_metadata_.Delete(); + } + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + + typedef + typename MapEntryImpl::KeyTypeHandler + KeyTypeHandler; + typedef typename MapEntryImpl< + Derived, Message, Key, Value, kKeyFieldType, kValueFieldType, + default_enum_value>::ValueTypeHandler ValueTypeHandler; + size_t SpaceUsedLong() const override { + size_t size = sizeof(Derived); + size += KeyTypeHandler::SpaceUsedInMapEntryLong(this->key_); + size += ValueTypeHandler::SpaceUsedInMapEntryLong(this->value_); + return size; + } + + InternalMetadata _internal_metadata_; + + private: + friend class ::PROTOBUF_NAMESPACE_ID::Arena; + template + friend class internal::MapField; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MapEntry); +}; + +// Specialization for the full runtime +template +struct MapEntryHelper > + : MapEntryHelper > { + explicit MapEntryHelper(const MapPair& map_pair) + : MapEntryHelper >( + map_pair) {} +}; + +template +struct DeconstructMapEntry > { + typedef K Key; + typedef V Value; + static constexpr WireFormatLite::FieldType kKeyFieldType = key; + static constexpr WireFormatLite::FieldType kValueFieldType = value; + static constexpr int default_enum_value = default_enum; +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_MAP_ENTRY_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_entry_lite.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_entry_lite.h new file mode 100644 index 0000000000000000000000000000000000000000..eb1a33b1937b979d09dc38664c1c937910a9f62d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_entry_lite.h @@ -0,0 +1,676 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_MAP_ENTRY_LITE_H__ +#define GOOGLE_PROTOBUF_MAP_ENTRY_LITE_H__ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +namespace internal { +template +class MapEntry; +template +class MapFieldLite; +} // namespace internal +} // namespace protobuf +} // namespace google + +namespace google { +namespace protobuf { +namespace internal { + +// MoveHelper::Move is used to set *dest. It copies *src, or moves it (in +// the C++11 sense), or swaps it. *src is left in a sane state for +// subsequent destruction, but shouldn't be used for anything. +template +struct MoveHelper { // primitives + static void Move(T* src, T* dest) { *dest = *src; } +}; + +template +struct MoveHelper { // enums + static void Move(T* src, T* dest) { *dest = *src; } + // T is an enum here, so allow conversions to and from int. + static void Move(T* src, int* dest) { *dest = static_cast(*src); } + static void Move(int* src, T* dest) { *dest = static_cast(*src); } +}; + +template +struct MoveHelper { // messages + static void Move(T* src, T* dest) { dest->Swap(src); } +}; + +template +struct MoveHelper { // strings and similar + static void Move(T* src, T* dest) { +#if __cplusplus >= 201103L + *dest = std::move(*src); +#else + dest->swap(*src); +#endif + } +}; + +// Functions for operating on a map entry. Does not contain any representation +// (this class is not intended to be instantiated). +template +struct MapEntryFuncs { + typedef MapTypeHandler KeyTypeHandler; + typedef MapTypeHandler ValueTypeHandler; + static const int kKeyFieldNumber = 1; + static const int kValueFieldNumber = 2; + + static uint8* InternalSerialize(int field_number, const Key& key, + const Value& value, uint8* ptr, + io::EpsCopyOutputStream* stream) { + ptr = stream->EnsureSpace(ptr); + ptr = WireFormatLite::WriteTagToArray( + field_number, WireFormatLite::WIRETYPE_LENGTH_DELIMITED, ptr); + ptr = io::CodedOutputStream::WriteVarint32ToArray(GetCachedSize(key, value), + ptr); + + ptr = KeyTypeHandler::Write(kKeyFieldNumber, key, ptr, stream); + return ValueTypeHandler::Write(kValueFieldNumber, value, ptr, stream); + } + + static size_t ByteSizeLong(const Key& key, const Value& value) { + // Tags for key and value will both be one byte (field numbers 1 and 2). + size_t inner_length = + 2 + KeyTypeHandler::ByteSize(key) + ValueTypeHandler::ByteSize(value); + return inner_length + io::CodedOutputStream::VarintSize32( + static_cast(inner_length)); + } + + static int GetCachedSize(const Key& key, const Value& value) { + // Tags for key and value will both be one byte (field numbers 1 and 2). + return 2 + KeyTypeHandler::GetCachedSize(key) + + ValueTypeHandler::GetCachedSize(value); + } +}; + +// MapEntryImpl is used to implement parsing and serialization of map entries. +// It uses Curious Recursive Template Pattern (CRTP) to provide the type of +// the eventual code to the template code. +template +class MapEntryImpl : public Base { + public: + typedef MapEntryFuncs Funcs; + + protected: + // Provide utilities to parse/serialize key/value. Provide utilities to + // manipulate internal stored type. + typedef MapTypeHandler KeyTypeHandler; + typedef MapTypeHandler ValueTypeHandler; + + // Define internal memory layout. Strings and messages are stored as + // pointers, while other types are stored as values. + typedef typename KeyTypeHandler::TypeOnMemory KeyOnMemory; + typedef typename ValueTypeHandler::TypeOnMemory ValueOnMemory; + + // Enum type cannot be used for MapTypeHandler::Read. Define a type + // which will replace Enum with int. + typedef typename KeyTypeHandler::MapEntryAccessorType KeyMapEntryAccessorType; + typedef + typename ValueTypeHandler::MapEntryAccessorType ValueMapEntryAccessorType; + + // Constants for field number. + static const int kKeyFieldNumber = 1; + static const int kValueFieldNumber = 2; + + // Constants for field tag. + static const uint8 kKeyTag = + GOOGLE_PROTOBUF_WIRE_FORMAT_MAKE_TAG(kKeyFieldNumber, KeyTypeHandler::kWireType); + static const uint8 kValueTag = GOOGLE_PROTOBUF_WIRE_FORMAT_MAKE_TAG( + kValueFieldNumber, ValueTypeHandler::kWireType); + static const size_t kTagSize = 1; + + public: + // Work-around for a compiler bug (see repeated_field.h). + typedef void MapEntryHasMergeTypeTrait; + typedef Derived EntryType; + typedef Key EntryKeyType; + typedef Value EntryValueType; + static const WireFormatLite::FieldType kEntryKeyFieldType = kKeyFieldType; + static const WireFormatLite::FieldType kEntryValueFieldType = kValueFieldType; + static const int kEntryDefaultEnumValue = default_enum_value; + + MapEntryImpl() { + KeyTypeHandler::Initialize(&key_, NULL); + ValueTypeHandler::InitializeMaybeByDefaultEnum(&value_, default_enum_value, + NULL); + _has_bits_[0] = 0; + } + + explicit MapEntryImpl(Arena* arena) : Base(arena) { + KeyTypeHandler::Initialize(&key_, arena); + ValueTypeHandler::InitializeMaybeByDefaultEnum(&value_, default_enum_value, + arena); + _has_bits_[0] = 0; + } + + ~MapEntryImpl() { + if (Base::GetArena() != NULL) return; + KeyTypeHandler::DeleteNoArena(key_); + ValueTypeHandler::DeleteNoArena(value_); + } + + // accessors ====================================================== + + virtual inline const KeyMapEntryAccessorType& key() const { + return KeyTypeHandler::GetExternalReference(key_); + } + virtual inline const ValueMapEntryAccessorType& value() const { + return ValueTypeHandler::DefaultIfNotInitialized( + value_, Derived::internal_default_instance()->value_); + } + inline KeyMapEntryAccessorType* mutable_key() { + set_has_key(); + return KeyTypeHandler::EnsureMutable(&key_, Base::GetArena()); + } + inline ValueMapEntryAccessorType* mutable_value() { + set_has_value(); + return ValueTypeHandler::EnsureMutable(&value_, Base::GetArena()); + } + + // implements MessageLite ========================================= + + // MapEntryImpl is for implementation only and this function isn't called + // anywhere. Just provide a fake implementation here for MessageLite. + std::string GetTypeName() const override { return ""; } + + void CheckTypeAndMergeFrom(const MessageLite& other) override { + MergeFromInternal(*::google::protobuf::internal::DownCast(&other)); + } + + const char* _InternalParse(const char* ptr, ParseContext* ctx) final { + while (!ctx->Done(&ptr)) { + uint32 tag; + ptr = ReadTag(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + if (tag == kKeyTag) { + set_has_key(); + KeyMapEntryAccessorType* key = mutable_key(); + ptr = KeyTypeHandler::Read(ptr, ctx, key); + if (!Derived::ValidateKey(key)) return nullptr; + } else if (tag == kValueTag) { + set_has_value(); + ValueMapEntryAccessorType* value = mutable_value(); + ptr = ValueTypeHandler::Read(ptr, ctx, value); + if (!Derived::ValidateValue(value)) return nullptr; + } else { + if (tag == 0 || WireFormatLite::GetTagWireType(tag) == + WireFormatLite::WIRETYPE_END_GROUP) { + ctx->SetLastTag(tag); + return ptr; + } + ptr = UnknownFieldParse(tag, static_cast(nullptr), ptr, + ctx); + } + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + } + return ptr; + } + + size_t ByteSizeLong() const override { + size_t size = 0; + size += kTagSize + static_cast(KeyTypeHandler::ByteSize(key())); + size += kTagSize + static_cast(ValueTypeHandler::ByteSize(value())); + return size; + } + + ::google::protobuf::uint8* _InternalSerialize(::google::protobuf::uint8* ptr, + io::EpsCopyOutputStream* stream) const override { + ptr = KeyTypeHandler::Write(kKeyFieldNumber, key(), ptr, stream); + return ValueTypeHandler::Write(kValueFieldNumber, value(), ptr, stream); + } + + // Don't override SerializeWithCachedSizesToArray. Use MessageLite's. + + int GetCachedSize() const override { + int size = 0; + size += has_key() ? static_cast(kTagSize) + + KeyTypeHandler::GetCachedSize(key()) + : 0; + size += has_value() ? static_cast(kTagSize) + + ValueTypeHandler::GetCachedSize(value()) + : 0; + return size; + } + + bool IsInitialized() const override { + return ValueTypeHandler::IsInitialized(value_); + } + + Base* New() const override { + Derived* entry = new Derived; + return entry; + } + + Base* New(Arena* arena) const override { + Derived* entry = Arena::CreateMessage(arena); + return entry; + } + + protected: + // We can't declare this function directly here as it would hide the other + // overload (const Message&). + void MergeFromInternal(const MapEntryImpl& from) { + if (from._has_bits_[0]) { + if (from.has_key()) { + KeyTypeHandler::EnsureMutable(&key_, Base::GetArena()); + KeyTypeHandler::Merge(from.key(), &key_, Base::GetArena()); + set_has_key(); + } + if (from.has_value()) { + ValueTypeHandler::EnsureMutable(&value_, Base::GetArena()); + ValueTypeHandler::Merge(from.value(), &value_, Base::GetArena()); + set_has_value(); + } + } + } + + public: + void Clear() override { + KeyTypeHandler::Clear(&key_, Base::GetArena()); + ValueTypeHandler::ClearMaybeByDefaultEnum(&value_, Base::GetArena(), + default_enum_value); + clear_has_key(); + clear_has_value(); + } + + static void InitAsDefaultInstance() { + Derived* d = const_cast(Derived::internal_default_instance()); + KeyTypeHandler::AssignDefaultValue(&d->key_); + ValueTypeHandler::AssignDefaultValue(&d->value_); + } + + // Parsing using MergePartialFromCodedStream, above, is not as + // efficient as it could be. This helper class provides a speedier way. + template + class Parser { + public: + explicit Parser(MapField* mf) : mf_(mf), map_(mf->MutableMap()) {} + ~Parser() { + if (entry_ != nullptr && entry_->GetArena() == nullptr) delete entry_; + } + + // This does what the typical MergePartialFromCodedStream() is expected to + // do, with the additional side-effect that if successful (i.e., if true is + // going to be its return value) it inserts the key-value pair into map_. + bool MergePartialFromCodedStream(io::CodedInputStream* input) { + // Look for the expected thing: a key and then a value. If it fails, + // invoke the enclosing class's MergePartialFromCodedStream, or return + // false if that would be pointless. + if (input->ExpectTag(kKeyTag)) { + if (!KeyTypeHandler::Read(input, &key_)) { + return false; + } + // Peek at the next byte to see if it is kValueTag. If not, bail out. + const void* data; + int size; + input->GetDirectBufferPointerInline(&data, &size); + // We could use memcmp here, but we don't bother. The tag is one byte. + static_assert(kTagSize == 1, "tag size must be 1"); + if (size > 0 && *reinterpret_cast(data) == kValueTag) { + typename Map::size_type map_size = map_->size(); + value_ptr_ = &(*map_)[key_]; + if (PROTOBUF_PREDICT_TRUE(map_size != map_->size())) { + // We created a new key-value pair. Fill in the value. + typedef + typename MapIf::type T; + input->Skip(kTagSize); // Skip kValueTag. + if (!ValueTypeHandler::Read(input, + reinterpret_cast(value_ptr_))) { + map_->erase(key_); // Failure! Undo insertion. + return false; + } + if (input->ExpectAtEnd()) return true; + return ReadBeyondKeyValuePair(input); + } + } + } else { + key_ = Key(); + } + + NewEntry(); + *entry_->mutable_key() = key_; + const bool result = entry_->MergePartialFromCodedStream(input); + if (result) UseKeyAndValueFromEntry(); + return result; + } + + const char* _InternalParse(const char* ptr, ParseContext* ctx) { + if (PROTOBUF_PREDICT_TRUE(!ctx->Done(&ptr) && *ptr == kKeyTag)) { + ptr = KeyTypeHandler::Read(ptr + 1, ctx, &key_); + if (PROTOBUF_PREDICT_FALSE(!ptr || !Derived::ValidateKey(&key_))) { + return nullptr; + } + if (PROTOBUF_PREDICT_TRUE(!ctx->Done(&ptr) && *ptr == kValueTag)) { + typename Map::size_type map_size = map_->size(); + value_ptr_ = &(*map_)[key_]; + if (PROTOBUF_PREDICT_TRUE(map_size != map_->size())) { + using T = + typename MapIf::type; + ptr = ValueTypeHandler::Read(ptr + 1, ctx, + reinterpret_cast(value_ptr_)); + if (PROTOBUF_PREDICT_FALSE(!ptr || + !Derived::ValidateValue(value_ptr_))) { + map_->erase(key_); // Failure! Undo insertion. + return nullptr; + } + if (PROTOBUF_PREDICT_TRUE(ctx->Done(&ptr))) return ptr; + if (!ptr) return nullptr; + NewEntry(); + ValueMover::Move(value_ptr_, entry_->mutable_value()); + map_->erase(key_); + goto move_key; + } + } else { + if (!ptr) return nullptr; + } + NewEntry(); + move_key: + KeyMover::Move(&key_, entry_->mutable_key()); + } else { + if (!ptr) return nullptr; + NewEntry(); + } + ptr = entry_->_InternalParse(ptr, ctx); + if (ptr) UseKeyAndValueFromEntry(); + return ptr; + } + + template + const char* ParseWithEnumValidation(const char* ptr, ParseContext* ctx, + bool (*is_valid)(int), uint32 field_num, + InternalMetadata* metadata) { + auto entry = NewEntry(); + ptr = entry->_InternalParse(ptr, ctx); + if (!ptr) return nullptr; + if (is_valid(entry->value())) { + UseKeyAndValueFromEntry(); + } else { + WriteLengthDelimited(field_num, entry->SerializeAsString(), + metadata->mutable_unknown_fields()); + } + return ptr; + } + + MapEntryImpl* NewEntry() { return entry_ = mf_->NewEntry(); } + + const Key& key() const { return key_; } + const Value& value() const { return *value_ptr_; } + + const Key& entry_key() const { return entry_->key(); } + const Value& entry_value() const { return entry_->value(); } + + private: + void UseKeyAndValueFromEntry() { + // Update key_ in case we need it later (because key() is called). + // This is potentially inefficient, especially if the key is + // expensive to copy (e.g., a long string), but this is a cold + // path, so it's not a big deal. + key_ = entry_->key(); + value_ptr_ = &(*map_)[key_]; + ValueMover::Move(entry_->mutable_value(), value_ptr_); + } + + // After reading a key and value successfully, and inserting that data + // into map_, we are not at the end of the input. This is unusual, but + // allowed by the spec. + bool ReadBeyondKeyValuePair(io::CodedInputStream* input) PROTOBUF_COLD { + NewEntry(); + ValueMover::Move(value_ptr_, entry_->mutable_value()); + map_->erase(key_); + KeyMover::Move(&key_, entry_->mutable_key()); + const bool result = entry_->MergePartialFromCodedStream(input); + if (result) UseKeyAndValueFromEntry(); + return result; + } + + typedef MoveHelper + KeyMover; + typedef MoveHelper + ValueMover; + + MapField* const mf_; + Map* const map_; + Key key_; + Value* value_ptr_; + MapEntryImpl* entry_ = nullptr; + }; + + protected: + void set_has_key() { _has_bits_[0] |= 0x00000001u; } + bool has_key() const { return (_has_bits_[0] & 0x00000001u) != 0; } + void clear_has_key() { _has_bits_[0] &= ~0x00000001u; } + void set_has_value() { _has_bits_[0] |= 0x00000002u; } + bool has_value() const { return (_has_bits_[0] & 0x00000002u) != 0; } + void clear_has_value() { _has_bits_[0] &= ~0x00000002u; } + + public: + inline Arena* GetArena() const { return Base::GetArena(); } + + public: // Needed for constructing tables + KeyOnMemory key_; + ValueOnMemory value_; + uint32 _has_bits_[1]; + + private: + friend class ::PROTOBUF_NAMESPACE_ID::Arena; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + template + friend class internal::MapEntry; + template + friend class internal::MapFieldLite; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MapEntryImpl); +}; + +template +class MapEntryLite + : public MapEntryImpl { + public: + typedef MapEntryImpl + SuperType; + MapEntryLite() {} + explicit MapEntryLite(Arena* arena) : SuperType(arena) {} + ~MapEntryLite() { MessageLite::_internal_metadata_.Delete(); } + void MergeFrom(const MapEntryLite& other) { MergeFromInternal(other); } + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MapEntryLite); +}; +// The completely unprincipled and unwieldy use of template parameters in +// the map code necessitates wrappers to make the code a little bit more +// manageable. +template +struct DeconstructMapEntry; + +template +struct DeconstructMapEntry > { + typedef K Key; + typedef V Value; + static const WireFormatLite::FieldType kKeyFieldType = key; + static const WireFormatLite::FieldType kValueFieldType = value; + static const int default_enum_value = default_enum; +}; + +// Helpers for deterministic serialization ============================= + +// This struct can be used with any generic sorting algorithm. If the Key +// type is relatively small and easy to copy then copying Keys into an +// array of SortItems can be beneficial. Then all the data the sorting +// algorithm needs to touch is in that one array. +template +struct SortItem { + SortItem() {} + explicit SortItem(PtrToKeyValuePair p) : first(p->first), second(p) {} + + Key first; + PtrToKeyValuePair second; +}; + +template +struct CompareByFirstField { + bool operator()(const T& a, const T& b) const { return a.first < b.first; } +}; + +template +struct CompareByDerefFirst { + bool operator()(const T& a, const T& b) const { return a->first < b->first; } +}; + +// Helper for table driven serialization + +template +struct FromHelper { + template + static const T& From(const T& x) { + return x; + } +}; + +template <> +struct FromHelper { + static ArenaStringPtr From(const std::string& x) { + ArenaStringPtr res; + TaggedPtr ptr; + ptr.Set(const_cast(&x)); + res.UnsafeSetTaggedPointer(ptr); + return res; + } +}; +template <> +struct FromHelper { + static ArenaStringPtr From(const std::string& x) { + ArenaStringPtr res; + TaggedPtr ptr; + ptr.Set(const_cast(&x)); + res.UnsafeSetTaggedPointer(ptr); + return res; + } +}; +template <> +struct FromHelper { + template + static T* From(const T& x) { + return const_cast(&x); + } +}; + +template +struct MapEntryHelper; + +template +struct MapEntryHelper > { + // Provide utilities to parse/serialize key/value. Provide utilities to + // manipulate internal stored type. + typedef MapTypeHandler KeyTypeHandler; + typedef MapTypeHandler ValueTypeHandler; + + // Define internal memory layout. Strings and messages are stored as + // pointers, while other types are stored as values. + typedef typename KeyTypeHandler::TypeOnMemory KeyOnMemory; + typedef typename ValueTypeHandler::TypeOnMemory ValueOnMemory; + + explicit MapEntryHelper(const MapPair& map_pair) + : _has_bits_(3), + _cached_size_(2 + KeyTypeHandler::GetCachedSize(map_pair.first) + + ValueTypeHandler::GetCachedSize(map_pair.second)), + key_(FromHelper::From(map_pair.first)), + value_(FromHelper::From(map_pair.second)) {} + + // Purposely not following the style guide naming. These are the names + // the proto compiler would generate given the map entry descriptor. + // The proto compiler generates the offsets in this struct as if this was + // a regular message. This way the table driven code barely notices it's + // dealing with a map field. + uint32 _has_bits_; // NOLINT + uint32 _cached_size_; // NOLINT + KeyOnMemory key_; // NOLINT + ValueOnMemory value_; // NOLINT +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_MAP_ENTRY_LITE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field.h new file mode 100644 index 0000000000000000000000000000000000000000..f168d9f5380e91ab75ac4d448949ed80d52d8173 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field.h @@ -0,0 +1,849 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_MAP_FIELD_H__ +#define GOOGLE_PROTOBUF_MAP_FIELD_H__ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +class DynamicMessage; +class MapIterator; + +#define TYPE_CHECK(EXPECTEDTYPE, METHOD) \ + if (type() != EXPECTEDTYPE) { \ + GOOGLE_LOG(FATAL) << "Protocol Buffer map usage error:\n" \ + << METHOD << " type does not match\n" \ + << " Expected : " \ + << FieldDescriptor::CppTypeName(EXPECTEDTYPE) << "\n" \ + << " Actual : " << FieldDescriptor::CppTypeName(type()); \ + } + +// MapKey is an union type for representing any possible +// map key. +class PROTOBUF_EXPORT MapKey { + public: + MapKey() : type_(0) {} + MapKey(const MapKey& other) : type_(0) { CopyFrom(other); } + + MapKey& operator=(const MapKey& other) { + CopyFrom(other); + return *this; + } + + ~MapKey() { + if (type_ == FieldDescriptor::CPPTYPE_STRING) { + val_.string_value_.Destruct(); + } + } + + FieldDescriptor::CppType type() const { + if (type_ == 0) { + GOOGLE_LOG(FATAL) << "Protocol Buffer map usage error:\n" + << "MapKey::type MapKey is not initialized. " + << "Call set methods to initialize MapKey."; + } + return (FieldDescriptor::CppType)type_; + } + + void SetInt64Value(int64 value) { + SetType(FieldDescriptor::CPPTYPE_INT64); + val_.int64_value_ = value; + } + void SetUInt64Value(uint64 value) { + SetType(FieldDescriptor::CPPTYPE_UINT64); + val_.uint64_value_ = value; + } + void SetInt32Value(int32 value) { + SetType(FieldDescriptor::CPPTYPE_INT32); + val_.int32_value_ = value; + } + void SetUInt32Value(uint32 value) { + SetType(FieldDescriptor::CPPTYPE_UINT32); + val_.uint32_value_ = value; + } + void SetBoolValue(bool value) { + SetType(FieldDescriptor::CPPTYPE_BOOL); + val_.bool_value_ = value; + } + void SetStringValue(std::string val) { + SetType(FieldDescriptor::CPPTYPE_STRING); + *val_.string_value_.get_mutable() = std::move(val); + } + + int64 GetInt64Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_INT64, "MapKey::GetInt64Value"); + return val_.int64_value_; + } + uint64 GetUInt64Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_UINT64, "MapKey::GetUInt64Value"); + return val_.uint64_value_; + } + int32 GetInt32Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_INT32, "MapKey::GetInt32Value"); + return val_.int32_value_; + } + uint32 GetUInt32Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_UINT32, "MapKey::GetUInt32Value"); + return val_.uint32_value_; + } + bool GetBoolValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_BOOL, "MapKey::GetBoolValue"); + return val_.bool_value_; + } + const std::string& GetStringValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_STRING, "MapKey::GetStringValue"); + return val_.string_value_.get(); + } + + bool operator<(const MapKey& other) const { + if (type_ != other.type_) { + // We could define a total order that handles this case, but + // there currently no need. So, for now, fail. + GOOGLE_LOG(FATAL) << "Unsupported: type mismatch"; + } + switch (type()) { + case FieldDescriptor::CPPTYPE_DOUBLE: + case FieldDescriptor::CPPTYPE_FLOAT: + case FieldDescriptor::CPPTYPE_ENUM: + case FieldDescriptor::CPPTYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Unsupported"; + return false; + case FieldDescriptor::CPPTYPE_STRING: + return val_.string_value_.get() < other.val_.string_value_.get(); + case FieldDescriptor::CPPTYPE_INT64: + return val_.int64_value_ < other.val_.int64_value_; + case FieldDescriptor::CPPTYPE_INT32: + return val_.int32_value_ < other.val_.int32_value_; + case FieldDescriptor::CPPTYPE_UINT64: + return val_.uint64_value_ < other.val_.uint64_value_; + case FieldDescriptor::CPPTYPE_UINT32: + return val_.uint32_value_ < other.val_.uint32_value_; + case FieldDescriptor::CPPTYPE_BOOL: + return val_.bool_value_ < other.val_.bool_value_; + } + return false; + } + + bool operator==(const MapKey& other) const { + if (type_ != other.type_) { + // To be consistent with operator<, we don't allow this either. + GOOGLE_LOG(FATAL) << "Unsupported: type mismatch"; + } + switch (type()) { + case FieldDescriptor::CPPTYPE_DOUBLE: + case FieldDescriptor::CPPTYPE_FLOAT: + case FieldDescriptor::CPPTYPE_ENUM: + case FieldDescriptor::CPPTYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Unsupported"; + break; + case FieldDescriptor::CPPTYPE_STRING: + return val_.string_value_.get() == other.val_.string_value_.get(); + case FieldDescriptor::CPPTYPE_INT64: + return val_.int64_value_ == other.val_.int64_value_; + case FieldDescriptor::CPPTYPE_INT32: + return val_.int32_value_ == other.val_.int32_value_; + case FieldDescriptor::CPPTYPE_UINT64: + return val_.uint64_value_ == other.val_.uint64_value_; + case FieldDescriptor::CPPTYPE_UINT32: + return val_.uint32_value_ == other.val_.uint32_value_; + case FieldDescriptor::CPPTYPE_BOOL: + return val_.bool_value_ == other.val_.bool_value_; + } + GOOGLE_LOG(FATAL) << "Can't get here."; + return false; + } + + void CopyFrom(const MapKey& other) { + SetType(other.type()); + switch (type_) { + case FieldDescriptor::CPPTYPE_DOUBLE: + case FieldDescriptor::CPPTYPE_FLOAT: + case FieldDescriptor::CPPTYPE_ENUM: + case FieldDescriptor::CPPTYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Unsupported"; + break; + case FieldDescriptor::CPPTYPE_STRING: + *val_.string_value_.get_mutable() = other.val_.string_value_.get(); + break; + case FieldDescriptor::CPPTYPE_INT64: + val_.int64_value_ = other.val_.int64_value_; + break; + case FieldDescriptor::CPPTYPE_INT32: + val_.int32_value_ = other.val_.int32_value_; + break; + case FieldDescriptor::CPPTYPE_UINT64: + val_.uint64_value_ = other.val_.uint64_value_; + break; + case FieldDescriptor::CPPTYPE_UINT32: + val_.uint32_value_ = other.val_.uint32_value_; + break; + case FieldDescriptor::CPPTYPE_BOOL: + val_.bool_value_ = other.val_.bool_value_; + break; + } + } + + private: + template + friend class internal::TypeDefinedMapFieldBase; + friend class ::PROTOBUF_NAMESPACE_ID::MapIterator; + friend class internal::DynamicMapField; + + union KeyValue { + KeyValue() {} + internal::ExplicitlyConstructed string_value_; + int64 int64_value_; + int32 int32_value_; + uint64 uint64_value_; + uint32 uint32_value_; + bool bool_value_; + } val_; + + void SetType(FieldDescriptor::CppType type) { + if (type_ == type) return; + if (type_ == FieldDescriptor::CPPTYPE_STRING) { + val_.string_value_.Destruct(); + } + type_ = type; + if (type_ == FieldDescriptor::CPPTYPE_STRING) { + val_.string_value_.DefaultConstruct(); + } + } + + // type_ is 0 or a valid FieldDescriptor::CppType. + int type_; +}; + +namespace internal { + +class ContendedMapCleanTest; +class GeneratedMessageReflection; +class MapFieldAccessor; + +// This class provides access to map field using reflection, which is the same +// as those provided for RepeatedPtrField. It is used for internal +// reflection implentation only. Users should never use this directly. +class PROTOBUF_EXPORT MapFieldBase { + public: + MapFieldBase() + : arena_(NULL), repeated_field_(NULL), state_(STATE_MODIFIED_MAP) {} + explicit MapFieldBase(Arena* arena) + : arena_(arena), repeated_field_(NULL), state_(STATE_MODIFIED_MAP) { + // Mutex's destructor needs to be called explicitly to release resources + // acquired in its constructor. + if (arena) { + arena->OwnDestructor(&mutex_); + } + } + virtual ~MapFieldBase(); + + // Returns reference to internal repeated field. Data written using + // Map's api prior to calling this function is guarantted to be + // included in repeated field. + const RepeatedPtrFieldBase& GetRepeatedField() const; + + // Like above. Returns mutable pointer to the internal repeated field. + RepeatedPtrFieldBase* MutableRepeatedField(); + + // Pure virtual map APIs for Map Reflection. + virtual bool ContainsMapKey(const MapKey& map_key) const = 0; + virtual bool InsertOrLookupMapValue(const MapKey& map_key, + MapValueRef* val) = 0; + // Returns whether changes to the map are reflected in the repeated field. + bool IsRepeatedFieldValid() const; + // Insures operations after won't get executed before calling this. + bool IsMapValid() const; + virtual bool DeleteMapValue(const MapKey& map_key) = 0; + virtual bool EqualIterator(const MapIterator& a, + const MapIterator& b) const = 0; + virtual void MapBegin(MapIterator* map_iter) const = 0; + virtual void MapEnd(MapIterator* map_iter) const = 0; + virtual void MergeFrom(const MapFieldBase& other) = 0; + virtual void Swap(MapFieldBase* other) = 0; + // Sync Map with repeated field and returns the size of map. + virtual int size() const = 0; + virtual void Clear() = 0; + + // Returns the number of bytes used by the repeated field, excluding + // sizeof(*this) + size_t SpaceUsedExcludingSelfLong() const; + + int SpaceUsedExcludingSelf() const { + return internal::ToIntSize(SpaceUsedExcludingSelfLong()); + } + + protected: + // Gets the size of space used by map field. + virtual size_t SpaceUsedExcludingSelfNoLock() const; + + // Synchronizes the content in Map to RepeatedPtrField if there is any change + // to Map after last synchronization. + void SyncRepeatedFieldWithMap() const; + virtual void SyncRepeatedFieldWithMapNoLock() const; + + // Synchronizes the content in RepeatedPtrField to Map if there is any change + // to RepeatedPtrField after last synchronization. + void SyncMapWithRepeatedField() const; + virtual void SyncMapWithRepeatedFieldNoLock() const {} + + // Tells MapFieldBase that there is new change to Map. + void SetMapDirty(); + + // Tells MapFieldBase that there is new change to RepeatedPTrField. + void SetRepeatedDirty(); + + // Provides derived class the access to repeated field. + void* MutableRepeatedPtrField() const; + + enum State { + STATE_MODIFIED_MAP = 0, // map has newly added data that has not been + // synchronized to repeated field + STATE_MODIFIED_REPEATED = 1, // repeated field has newly added data that + // has not been synchronized to map + CLEAN = 2, // data in map and repeated field are same + }; + + Arena* arena_; + mutable RepeatedPtrField* repeated_field_; + + mutable internal::WrappedMutex + mutex_; // The thread to synchronize map and repeated field + // needs to get lock first; + mutable std::atomic state_; + + private: + friend class ContendedMapCleanTest; + friend class GeneratedMessageReflection; + friend class MapFieldAccessor; + friend class ::PROTOBUF_NAMESPACE_ID::DynamicMessage; + + // Virtual helper methods for MapIterator. MapIterator doesn't have the + // type helper for key and value. Call these help methods to deal with + // different types. Real helper methods are implemented in + // TypeDefinedMapFieldBase. + friend class ::PROTOBUF_NAMESPACE_ID::MapIterator; + // Allocate map<...>::iterator for MapIterator. + virtual void InitializeIterator(MapIterator* map_iter) const = 0; + + // DeleteIterator() is called by the destructor of MapIterator only. + // It deletes map<...>::iterator for MapIterator. + virtual void DeleteIterator(MapIterator* map_iter) const = 0; + + // Copy the map<...>::iterator from other_iterator to + // this_iterator. + virtual void CopyIterator(MapIterator* this_iterator, + const MapIterator& other_iterator) const = 0; + + // IncreaseIterator() is called by operator++() of MapIterator only. + // It implements the ++ operator of MapIterator. + virtual void IncreaseIterator(MapIterator* map_iter) const = 0; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MapFieldBase); +}; + +// This class provides common Map Reflection implementations for generated +// message and dynamic message. +template +class TypeDefinedMapFieldBase : public MapFieldBase { + public: + TypeDefinedMapFieldBase() {} + explicit TypeDefinedMapFieldBase(Arena* arena) : MapFieldBase(arena) {} + ~TypeDefinedMapFieldBase() override {} + void MapBegin(MapIterator* map_iter) const override; + void MapEnd(MapIterator* map_iter) const override; + bool EqualIterator(const MapIterator& a, const MapIterator& b) const override; + + virtual const Map& GetMap() const = 0; + virtual Map* MutableMap() = 0; + + protected: + typename Map::const_iterator& InternalGetIterator( + const MapIterator* map_iter) const; + + private: + void InitializeIterator(MapIterator* map_iter) const override; + void DeleteIterator(MapIterator* map_iter) const override; + void CopyIterator(MapIterator* this_iteratorm, + const MapIterator& that_iterator) const override; + void IncreaseIterator(MapIterator* map_iter) const override; + + virtual void SetMapIteratorValue(MapIterator* map_iter) const = 0; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(TypeDefinedMapFieldBase); +}; + +// This class provides access to map field using generated api. It is used for +// internal generated message implentation only. Users should never use this +// directly. +template +class MapField : public TypeDefinedMapFieldBase { + // Provide utilities to parse/serialize key/value. Provide utilities to + // manipulate internal stored type. + typedef MapTypeHandler KeyTypeHandler; + typedef MapTypeHandler ValueTypeHandler; + + // Define message type for internal repeated field. + typedef Derived EntryType; + + // Define abbreviation for parent MapFieldLite + typedef MapFieldLite + MapFieldLiteType; + + // Enum needs to be handled differently from other types because it has + // different exposed type in Map's api and repeated field's api. For + // details see the comment in the implementation of + // SyncMapWithRepeatedFieldNoLock. + static constexpr bool kIsValueEnum = ValueTypeHandler::kIsEnum; + typedef typename MapIf::type CastValueType; + + public: + typedef typename Derived::SuperType EntryTypeTrait; + typedef Map MapType; + + MapField() {} + explicit MapField(Arena* arena) + : TypeDefinedMapFieldBase(arena), impl_(arena) {} + + // Implement MapFieldBase + bool ContainsMapKey(const MapKey& map_key) const override; + bool InsertOrLookupMapValue(const MapKey& map_key, MapValueRef* val) override; + bool DeleteMapValue(const MapKey& map_key) override; + + const Map& GetMap() const override { + MapFieldBase::SyncMapWithRepeatedField(); + return impl_.GetMap(); + } + + Map* MutableMap() override { + MapFieldBase::SyncMapWithRepeatedField(); + Map* result = impl_.MutableMap(); + MapFieldBase::SetMapDirty(); + return result; + } + + int size() const override; + void Clear() override; + void MergeFrom(const MapFieldBase& other) override; + void Swap(MapFieldBase* other) override; + + // Used in the implementation of parsing. Caller should take the ownership iff + // arena_ is NULL. + EntryType* NewEntry() const { return impl_.NewEntry(); } + // Used in the implementation of serializing enum value type. Caller should + // take the ownership iff arena_ is NULL. + EntryType* NewEnumEntryWrapper(const Key& key, const T t) const { + return impl_.NewEnumEntryWrapper(key, t); + } + // Used in the implementation of serializing other value types. Caller should + // take the ownership iff arena_ is NULL. + EntryType* NewEntryWrapper(const Key& key, const T& t) const { + return impl_.NewEntryWrapper(key, t); + } + + const char* _InternalParse(const char* ptr, ParseContext* ctx) { + return impl_._InternalParse(ptr, ctx); + } + template + const char* ParseWithEnumValidation(const char* ptr, ParseContext* ctx, + bool (*is_valid)(int), uint32 field_num, + InternalMetadata* metadata) { + return impl_.template ParseWithEnumValidation( + ptr, ctx, is_valid, field_num, metadata); + } + + private: + MapFieldLiteType impl_; + + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + + // Implements MapFieldBase + void SyncRepeatedFieldWithMapNoLock() const override; + void SyncMapWithRepeatedFieldNoLock() const override; + size_t SpaceUsedExcludingSelfNoLock() const override; + + void SetMapIteratorValue(MapIterator* map_iter) const override; + + friend class ::PROTOBUF_NAMESPACE_ID::Arena; + friend class MapFieldStateTest; // For testing, it needs raw access to impl_ + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MapField); +}; + +template +bool AllAreInitialized( + const MapField& field) { + const auto& t = field.GetMap(); + for (typename Map::const_iterator it = t.begin(); it != t.end(); + ++it) { + if (!it->second.IsInitialized()) return false; + } + return true; +} + +template +struct MapEntryToMapField> { + typedef MapField + MapFieldType; +}; + +class PROTOBUF_EXPORT DynamicMapField + : public TypeDefinedMapFieldBase { + public: + explicit DynamicMapField(const Message* default_entry); + DynamicMapField(const Message* default_entry, Arena* arena); + ~DynamicMapField() override; + + // Implement MapFieldBase + bool ContainsMapKey(const MapKey& map_key) const override; + bool InsertOrLookupMapValue(const MapKey& map_key, MapValueRef* val) override; + bool DeleteMapValue(const MapKey& map_key) override; + void MergeFrom(const MapFieldBase& other) override; + void Swap(MapFieldBase* other) override; + + const Map& GetMap() const override; + Map* MutableMap() override; + + int size() const override; + void Clear() override; + + private: + Map map_; + const Message* default_entry_; + + void AllocateMapValue(MapValueRef* map_val); + + // Implements MapFieldBase + void SyncRepeatedFieldWithMapNoLock() const override; + void SyncMapWithRepeatedFieldNoLock() const override; + size_t SpaceUsedExcludingSelfNoLock() const override; + void SetMapIteratorValue(MapIterator* map_iter) const override; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DynamicMapField); +}; + +} // namespace internal + +// MapValueRef points to a map value. +class PROTOBUF_EXPORT MapValueRef { + public: + MapValueRef() : data_(NULL), type_(0) {} + + void SetInt64Value(int64 value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_INT64, "MapValueRef::SetInt64Value"); + *reinterpret_cast(data_) = value; + } + void SetUInt64Value(uint64 value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_UINT64, "MapValueRef::SetUInt64Value"); + *reinterpret_cast(data_) = value; + } + void SetInt32Value(int32 value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_INT32, "MapValueRef::SetInt32Value"); + *reinterpret_cast(data_) = value; + } + void SetUInt32Value(uint32 value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_UINT32, "MapValueRef::SetUInt32Value"); + *reinterpret_cast(data_) = value; + } + void SetBoolValue(bool value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_BOOL, "MapValueRef::SetBoolValue"); + *reinterpret_cast(data_) = value; + } + // TODO(jieluo) - Checks that enum is member. + void SetEnumValue(int value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_ENUM, "MapValueRef::SetEnumValue"); + *reinterpret_cast(data_) = value; + } + void SetStringValue(const std::string& value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_STRING, "MapValueRef::SetStringValue"); + *reinterpret_cast(data_) = value; + } + void SetFloatValue(float value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_FLOAT, "MapValueRef::SetFloatValue"); + *reinterpret_cast(data_) = value; + } + void SetDoubleValue(double value) { + TYPE_CHECK(FieldDescriptor::CPPTYPE_DOUBLE, "MapValueRef::SetDoubleValue"); + *reinterpret_cast(data_) = value; + } + + int64 GetInt64Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_INT64, "MapValueRef::GetInt64Value"); + return *reinterpret_cast(data_); + } + uint64 GetUInt64Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_UINT64, "MapValueRef::GetUInt64Value"); + return *reinterpret_cast(data_); + } + int32 GetInt32Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_INT32, "MapValueRef::GetInt32Value"); + return *reinterpret_cast(data_); + } + uint32 GetUInt32Value() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_UINT32, "MapValueRef::GetUInt32Value"); + return *reinterpret_cast(data_); + } + bool GetBoolValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_BOOL, "MapValueRef::GetBoolValue"); + return *reinterpret_cast(data_); + } + int GetEnumValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_ENUM, "MapValueRef::GetEnumValue"); + return *reinterpret_cast(data_); + } + const std::string& GetStringValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_STRING, "MapValueRef::GetStringValue"); + return *reinterpret_cast(data_); + } + float GetFloatValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_FLOAT, "MapValueRef::GetFloatValue"); + return *reinterpret_cast(data_); + } + double GetDoubleValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_DOUBLE, "MapValueRef::GetDoubleValue"); + return *reinterpret_cast(data_); + } + + const Message& GetMessageValue() const { + TYPE_CHECK(FieldDescriptor::CPPTYPE_MESSAGE, + "MapValueRef::GetMessageValue"); + return *reinterpret_cast(data_); + } + + Message* MutableMessageValue() { + TYPE_CHECK(FieldDescriptor::CPPTYPE_MESSAGE, + "MapValueRef::MutableMessageValue"); + return reinterpret_cast(data_); + } + + private: + template + friend class internal::MapField; + template + friend class internal::TypeDefinedMapFieldBase; + friend class ::PROTOBUF_NAMESPACE_ID::MapIterator; + friend class Reflection; + friend class internal::DynamicMapField; + + void SetType(FieldDescriptor::CppType type) { type_ = type; } + + FieldDescriptor::CppType type() const { + if (type_ == 0 || data_ == NULL) { + GOOGLE_LOG(FATAL) << "Protocol Buffer map usage error:\n" + << "MapValueRef::type MapValueRef is not initialized."; + } + return (FieldDescriptor::CppType)type_; + } + void SetValue(const void* val) { data_ = const_cast(val); } + void CopyFrom(const MapValueRef& other) { + type_ = other.type_; + data_ = other.data_; + } + // Only used in DynamicMapField + void DeleteData() { + switch (type_) { +#define HANDLE_TYPE(CPPTYPE, TYPE) \ + case FieldDescriptor::CPPTYPE_##CPPTYPE: { \ + delete reinterpret_cast(data_); \ + break; \ + } + HANDLE_TYPE(INT32, int32); + HANDLE_TYPE(INT64, int64); + HANDLE_TYPE(UINT32, uint32); + HANDLE_TYPE(UINT64, uint64); + HANDLE_TYPE(DOUBLE, double); + HANDLE_TYPE(FLOAT, float); + HANDLE_TYPE(BOOL, bool); + HANDLE_TYPE(STRING, std::string); + HANDLE_TYPE(ENUM, int32); + HANDLE_TYPE(MESSAGE, Message); +#undef HANDLE_TYPE + } + } + // data_ point to a map value. MapValueRef does not + // own this value. + void* data_; + // type_ is 0 or a valid FieldDescriptor::CppType. + int type_; +}; + +#undef TYPE_CHECK + +class PROTOBUF_EXPORT MapIterator { + public: + MapIterator(Message* message, const FieldDescriptor* field) { + const Reflection* reflection = message->GetReflection(); + map_ = reflection->MutableMapData(message, field); + key_.SetType(field->message_type()->FindFieldByName("key")->cpp_type()); + value_.SetType(field->message_type()->FindFieldByName("value")->cpp_type()); + map_->InitializeIterator(this); + } + MapIterator(const MapIterator& other) { + map_ = other.map_; + map_->InitializeIterator(this); + map_->CopyIterator(this, other); + } + ~MapIterator() { map_->DeleteIterator(this); } + MapIterator& operator=(const MapIterator& other) { + map_ = other.map_; + map_->CopyIterator(this, other); + return *this; + } + friend bool operator==(const MapIterator& a, const MapIterator& b) { + return a.map_->EqualIterator(a, b); + } + friend bool operator!=(const MapIterator& a, const MapIterator& b) { + return !a.map_->EqualIterator(a, b); + } + MapIterator& operator++() { + map_->IncreaseIterator(this); + return *this; + } + MapIterator operator++(int) { + // iter_ is copied from Map<...>::iterator, no need to + // copy from its self again. Use the same implementation + // with operator++() + map_->IncreaseIterator(this); + return *this; + } + const MapKey& GetKey() { return key_; } + const MapValueRef& GetValueRef() { return value_; } + MapValueRef* MutableValueRef() { + map_->SetMapDirty(); + return &value_; + } + + private: + template + friend class internal::TypeDefinedMapFieldBase; + friend class internal::DynamicMapField; + template + friend class internal::MapField; + + // reinterpret_cast from heap-allocated Map<...>::iterator*. MapIterator owns + // the iterator. It is allocated by MapField<...>::InitializeIterator() called + // in constructor and deleted by MapField<...>::DeleteIterator() called in + // destructor. + void* iter_; + // Point to a MapField to call helper methods implemented in MapField. + // MapIterator does not own this object. + internal::MapFieldBase* map_; + MapKey key_; + MapValueRef value_; +}; + +} // namespace protobuf +} // namespace google + +namespace std { +template <> +struct hash<::PROTOBUF_NAMESPACE_ID::MapKey> { + size_t operator()(const ::PROTOBUF_NAMESPACE_ID::MapKey& map_key) const { + switch (map_key.type()) { + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_DOUBLE: + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_FLOAT: + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_ENUM: + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Unsupported"; + break; + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_STRING: + return hash()(map_key.GetStringValue()); + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_INT64: { + auto value = map_key.GetInt64Value(); + return hash()(value); + } + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_INT32: { + auto value = map_key.GetInt32Value(); + return hash()(map_key.GetInt32Value()); + } + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_UINT64: { + auto value = map_key.GetUInt64Value(); + return hash()(map_key.GetUInt64Value()); + } + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_UINT32: { + auto value = map_key.GetUInt32Value(); + return hash()(map_key.GetUInt32Value()); + } + case ::PROTOBUF_NAMESPACE_ID::FieldDescriptor::CPPTYPE_BOOL: { + return hash()(map_key.GetBoolValue()); + } + } + GOOGLE_LOG(FATAL) << "Can't get here."; + return 0; + } + bool operator()(const ::PROTOBUF_NAMESPACE_ID::MapKey& map_key1, + const ::PROTOBUF_NAMESPACE_ID::MapKey& map_key2) const { + return map_key1 < map_key2; + } +}; +} // namespace std +#include + +#endif // GOOGLE_PROTOBUF_MAP_FIELD_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field_inl.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..bc4a6cc718cd19e17c4d00398147b0846d9bfbe6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field_inl.h @@ -0,0 +1,362 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_MAP_FIELD_INL_H__ +#define GOOGLE_PROTOBUF_MAP_FIELD_INL_H__ + +#include + +#include +#include +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +namespace internal { +// UnwrapMapKey template +template +T UnwrapMapKey(const MapKey& map_key); +template <> +inline int32 UnwrapMapKey(const MapKey& map_key) { + return map_key.GetInt32Value(); +} +template <> +inline uint32 UnwrapMapKey(const MapKey& map_key) { + return map_key.GetUInt32Value(); +} +template <> +inline int64 UnwrapMapKey(const MapKey& map_key) { + return map_key.GetInt64Value(); +} +template <> +inline uint64 UnwrapMapKey(const MapKey& map_key) { + return map_key.GetUInt64Value(); +} +template <> +inline bool UnwrapMapKey(const MapKey& map_key) { + return map_key.GetBoolValue(); +} +template <> +inline std::string UnwrapMapKey(const MapKey& map_key) { + return map_key.GetStringValue(); +} + +// SetMapKey template +template +inline void SetMapKey(MapKey* map_key, const T& value); +template <> +inline void SetMapKey(MapKey* map_key, const int32& value) { + map_key->SetInt32Value(value); +} +template <> +inline void SetMapKey(MapKey* map_key, const uint32& value) { + map_key->SetUInt32Value(value); +} +template <> +inline void SetMapKey(MapKey* map_key, const int64& value) { + map_key->SetInt64Value(value); +} +template <> +inline void SetMapKey(MapKey* map_key, const uint64& value) { + map_key->SetUInt64Value(value); +} +template <> +inline void SetMapKey(MapKey* map_key, const bool& value) { + map_key->SetBoolValue(value); +} +template <> +inline void SetMapKey(MapKey* map_key, const std::string& value) { + map_key->SetStringValue(value); +} + +// ------------------------TypeDefinedMapFieldBase--------------- +template +typename Map::const_iterator& +TypeDefinedMapFieldBase::InternalGetIterator( + const MapIterator* map_iter) const { + return *reinterpret_cast::const_iterator*>( + map_iter->iter_); +} + +template +void TypeDefinedMapFieldBase::MapBegin(MapIterator* map_iter) const { + InternalGetIterator(map_iter) = GetMap().begin(); + SetMapIteratorValue(map_iter); +} + +template +void TypeDefinedMapFieldBase::MapEnd(MapIterator* map_iter) const { + InternalGetIterator(map_iter) = GetMap().end(); +} + +template +bool TypeDefinedMapFieldBase::EqualIterator( + const MapIterator& a, const MapIterator& b) const { + return InternalGetIterator(&a) == InternalGetIterator(&b); +} + +template +void TypeDefinedMapFieldBase::IncreaseIterator( + MapIterator* map_iter) const { + ++InternalGetIterator(map_iter); + SetMapIteratorValue(map_iter); +} + +template +void TypeDefinedMapFieldBase::InitializeIterator( + MapIterator* map_iter) const { + map_iter->iter_ = new typename Map::const_iterator; + GOOGLE_CHECK(map_iter->iter_ != NULL); +} + +template +void TypeDefinedMapFieldBase::DeleteIterator( + MapIterator* map_iter) const { + delete reinterpret_cast::const_iterator*>( + map_iter->iter_); +} + +template +void TypeDefinedMapFieldBase::CopyIterator( + MapIterator* this_iter, const MapIterator& that_iter) const { + InternalGetIterator(this_iter) = InternalGetIterator(&that_iter); + this_iter->key_.SetType(that_iter.key_.type()); + // MapValueRef::type() fails when containing data is null. However, if + // this_iter points to MapEnd, data can be null. + this_iter->value_.SetType( + static_cast(that_iter.value_.type_)); + SetMapIteratorValue(this_iter); +} + +// ---------------------------------------------------------------------- + +template +int MapField::size() const { + MapFieldBase::SyncMapWithRepeatedField(); + return static_cast(impl_.GetMap().size()); +} + +template +void MapField::Clear() { + if (this->MapFieldBase::repeated_field_ != nullptr) { + RepeatedPtrField* repeated_field = + reinterpret_cast*>( + this->MapFieldBase::repeated_field_); + repeated_field->Clear(); + } + + impl_.MutableMap()->clear(); + // Data in map and repeated field are both empty, but we can't set status + // CLEAN. Because clear is a generated API, we cannot invalidate previous + // reference to map. + MapFieldBase::SetMapDirty(); +} + +template +void MapField::SetMapIteratorValue(MapIterator* map_iter) + const { + const Map& map = impl_.GetMap(); + typename Map::const_iterator iter = + TypeDefinedMapFieldBase::InternalGetIterator(map_iter); + if (iter == map.end()) return; + SetMapKey(&map_iter->key_, iter->first); + map_iter->value_.SetValue(&iter->second); +} + +template +bool MapField::ContainsMapKey(const MapKey& map_key) const { + const Map& map = impl_.GetMap(); + const Key& key = UnwrapMapKey(map_key); + typename Map::const_iterator iter = map.find(key); + return iter != map.end(); +} + +template +bool MapField::InsertOrLookupMapValue(const MapKey& map_key, + MapValueRef* val) { + // Always use mutable map because users may change the map value by + // MapValueRef. + Map* map = MutableMap(); + const Key& key = UnwrapMapKey(map_key); + typename Map::iterator iter = map->find(key); + if (map->end() == iter) { + val->SetValue(&((*map)[key])); + return true; + } + // Key is already in the map. Make sure (*map)[key] is not called. + // [] may reorder the map and iterators. + val->SetValue(&(iter->second)); + return false; +} + +template +bool MapField::DeleteMapValue(const MapKey& map_key) { + const Key& key = UnwrapMapKey(map_key); + return MutableMap()->erase(key); +} + +template +void MapField::MergeFrom(const MapFieldBase& other) { + MapFieldBase::SyncMapWithRepeatedField(); + const MapField& other_field = static_cast(other); + other_field.SyncMapWithRepeatedField(); + impl_.MergeFrom(other_field.impl_); + MapFieldBase::SetMapDirty(); +} + +template +void MapField::Swap(MapFieldBase* other) { + MapField* other_field = down_cast(other); + std::swap(this->MapFieldBase::repeated_field_, other_field->repeated_field_); + impl_.Swap(&other_field->impl_); + // a relaxed swap of the atomic + auto other_state = other_field->state_.load(std::memory_order_relaxed); + auto this_state = this->MapFieldBase::state_.load(std::memory_order_relaxed); + other_field->state_.store(this_state, std::memory_order_relaxed); + this->MapFieldBase::state_.store(other_state, std::memory_order_relaxed); +} + +template +void MapField::SyncRepeatedFieldWithMapNoLock() const { + if (this->MapFieldBase::repeated_field_ == NULL) { + if (this->MapFieldBase::arena_ == NULL) { + this->MapFieldBase::repeated_field_ = new RepeatedPtrField(); + } else { + this->MapFieldBase::repeated_field_ = + Arena::CreateMessage >( + this->MapFieldBase::arena_); + } + } + const Map& map = impl_.GetMap(); + RepeatedPtrField* repeated_field = + reinterpret_cast*>( + this->MapFieldBase::repeated_field_); + + repeated_field->Clear(); + + // The only way we can get at this point is through reflection and the + // only way we can get the reflection object is by having called GetReflection + // on the encompassing field. So that type must have existed and hence we + // know that this MapEntry default_type has also already been constructed. + // So it's safe to just call internal_default_instance(). + const Message* default_entry = Derived::internal_default_instance(); + for (typename Map::const_iterator it = map.begin(); it != map.end(); + ++it) { + EntryType* new_entry = + down_cast(default_entry->New(this->MapFieldBase::arena_)); + repeated_field->AddAllocated(new_entry); + (*new_entry->mutable_key()) = it->first; + (*new_entry->mutable_value()) = it->second; + } +} + +template +void MapField::SyncMapWithRepeatedFieldNoLock() const { + Map* map = const_cast(this)->impl_.MutableMap(); + RepeatedPtrField* repeated_field = + reinterpret_cast*>( + this->MapFieldBase::repeated_field_); + GOOGLE_CHECK(this->MapFieldBase::repeated_field_ != NULL); + map->clear(); + for (typename RepeatedPtrField::iterator it = + repeated_field->begin(); + it != repeated_field->end(); ++it) { + // Cast is needed because Map's api and internal storage is different when + // value is enum. For enum, we cannot cast an int to enum. Thus, we have to + // copy value. For other types, they have same exposed api type and internal + // stored type. We should not introduce value copy for them. We achieve this + // by casting to value for enum while casting to reference for other types. + (*map)[it->key()] = static_cast(it->value()); + } +} + +template +size_t MapField::SpaceUsedExcludingSelfNoLock() const { + size_t size = 0; + if (this->MapFieldBase::repeated_field_ != NULL) { + size += this->MapFieldBase::repeated_field_->SpaceUsedExcludingSelfLong(); + } + Map* map = const_cast(this)->impl_.MutableMap(); + size += sizeof(*map); + for (typename Map::iterator it = map->begin(); it != map->end(); + ++it) { + size += KeyTypeHandler::SpaceUsedInMapLong(it->first); + size += ValueTypeHandler::SpaceUsedInMapLong(it->second); + } + return size; +} +} // namespace internal +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_MAP_FIELD_INL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field_lite.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field_lite.h new file mode 100644 index 0000000000000000000000000000000000000000..a8e04ca67aa1cfbe4f980ad5292140dcd04e022a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_field_lite.h @@ -0,0 +1,195 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_MAP_FIELD_LITE_H__ +#define GOOGLE_PROTOBUF_MAP_FIELD_LITE_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +namespace internal { + +// This class provides access to map field using generated api. It is used for +// internal generated message implentation only. Users should never use this +// directly. +template +class MapFieldLite { + // Define message type for internal repeated field. + typedef Derived EntryType; + + public: + typedef Map MapType; + typedef EntryType EntryTypeTrait; + + MapFieldLite() { SetDefaultEnumValue(); } + + explicit MapFieldLite(Arena* arena) : map_(arena) { SetDefaultEnumValue(); } + + // Accessors + const Map& GetMap() const { return map_; } + Map* MutableMap() { return &map_; } + + // Convenient methods for generated message implementation. + int size() const { return static_cast(map_.size()); } + void Clear() { return map_.clear(); } + void MergeFrom(const MapFieldLite& other) { + for (typename Map::const_iterator it = other.map_.begin(); + it != other.map_.end(); ++it) { + map_[it->first] = it->second; + } + } + void Swap(MapFieldLite* other) { map_.swap(other->map_); } + + // Set default enum value only for proto2 map field whose value is enum type. + void SetDefaultEnumValue() { + MutableMap()->SetDefaultEnumValue(default_enum_value); + } + + // Used in the implementation of parsing. Caller should take the ownership iff + // arena_ is NULL. + EntryType* NewEntry() const { + return Arena::CreateMessage(map_.arena_); + } + // Used in the implementation of serializing enum value type. Caller should + // take the ownership iff arena_ is NULL. + EntryType* NewEnumEntryWrapper(const Key& key, const T t) const { + return EntryType::EnumWrap(key, t, map_.arena_); + } + // Used in the implementation of serializing other value types. Caller should + // take the ownership iff arena_ is NULL. + EntryType* NewEntryWrapper(const Key& key, const T& t) const { + return EntryType::Wrap(key, t, map_.arena_); + } + + const char* _InternalParse(const char* ptr, ParseContext* ctx) { + typename Derived::template Parser> parser(this); + return parser._InternalParse(ptr, ctx); + } + + template + const char* ParseWithEnumValidation(const char* ptr, ParseContext* ctx, + bool (*is_valid)(int), uint32 field_num, + InternalMetadata* metadata) { + typename Derived::template Parser> parser(this); + return parser.template ParseWithEnumValidation( + ptr, ctx, is_valid, field_num, metadata); + } + + private: + typedef void DestructorSkippable_; + + Map map_; + + friend class ::PROTOBUF_NAMESPACE_ID::Arena; +}; + +template +struct EnumParseWrapper { + const char* _InternalParse(const char* ptr, ParseContext* ctx) { + return map_field->template ParseWithEnumValidation( + ptr, ctx, is_valid, field_num, metadata); + } + T* map_field; + bool (*is_valid)(int); + uint32 field_num; + InternalMetadata* metadata; +}; + +// Helper function because the typenames of maps are horrendous to print. This +// leverages compiler type deduction, to keep all type data out of the +// generated code +template +EnumParseWrapper InitEnumParseWrapper( + T* map_field, bool (*is_valid)(int), uint32 field_num, + InternalMetadata* metadata) { + return EnumParseWrapper{map_field, is_valid, field_num, + metadata}; +} + +// True if IsInitialized() is true for value field in all elements of t. T is +// expected to be message. It's useful to have this helper here to keep the +// protobuf compiler from ever having to emit loops in IsInitialized() methods. +// We want the C++ compiler to inline this or not as it sees fit. +template +bool AllAreInitialized( + const MapFieldLite& field) { + const auto& t = field.GetMap(); + for (typename Map::const_iterator it = t.begin(); it != t.end(); + ++it) { + if (!it->second.IsInitialized()) return false; + } + return true; +} + +template +struct MapEntryToMapField : MapEntryToMapField {}; + +template +struct MapEntryToMapField> { + typedef MapFieldLite, + Key, Value, kKeyFieldType, kValueFieldType, + default_enum_value> + MapFieldType; +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_MAP_FIELD_LITE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_type_handler.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_type_handler.h new file mode 100644 index 0000000000000000000000000000000000000000..d0169bef30b5f7a3a41bfe301750b8ca3aa93a08 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/map_type_handler.h @@ -0,0 +1,812 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_TYPE_HANDLER_H__ +#define GOOGLE_PROTOBUF_TYPE_HANDLER_H__ + +#include +#include +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +namespace internal { + +// Used for compile time type selection. MapIf::type will be TrueType if Flag is +// true and FalseType otherwise. +template +struct MapIf; + +template +struct MapIf { + typedef TrueType type; +}; + +template +struct MapIf { + typedef FalseType type; +}; + +// In proto2 Map, enum needs to be initialized to given default value, while +// other types' default value can be inferred from the type. +template +class MapValueInitializer { + public: + static inline void Initialize(Type& type, int default_enum_value); +}; + +template +class MapValueInitializer { + public: + static inline void Initialize(Type& value, int default_enum_value) { + value = static_cast(default_enum_value); + } +}; + +template +class MapValueInitializer { + public: + static inline void Initialize(Type& /* value */, + int /* default_enum_value */) {} +}; + +template +class MapArenaMessageCreator { + public: + // Use arena to create message if Type is arena constructable. Otherwise, + // create the message on heap. + static inline Type* CreateMessage(Arena* arena); +}; +template +class MapArenaMessageCreator { + public: + static inline Type* CreateMessage(Arena* arena) { + return Arena::CreateMessage(arena); + } +}; +template +class MapArenaMessageCreator { + public: + static inline Type* CreateMessage(Arena* arena) { + return Arena::Create(arena); + } +}; + +// Define constants for given wire field type +template +class MapWireFieldTypeTraits {}; + +#define TYPE_TRAITS(FieldType, CType, WireFormatType, IsMessage, IsEnum) \ + template \ + class MapWireFieldTypeTraits { \ + public: \ + static const bool kIsMessage = IsMessage; \ + static const bool kIsEnum = IsEnum; \ + typedef typename MapIf::type TypeOnMemory; \ + typedef typename MapIf::type MapEntryAccessorType; \ + static const WireFormatLite::WireType kWireType = \ + WireFormatLite::WIRETYPE_##WireFormatType; \ + }; + +TYPE_TRAITS(MESSAGE, Type, LENGTH_DELIMITED, true, false) +TYPE_TRAITS(STRING, ArenaStringPtr, LENGTH_DELIMITED, false, false) +TYPE_TRAITS(BYTES, ArenaStringPtr, LENGTH_DELIMITED, false, false) +TYPE_TRAITS(INT64, int64, VARINT, false, false) +TYPE_TRAITS(UINT64, uint64, VARINT, false, false) +TYPE_TRAITS(INT32, int32, VARINT, false, false) +TYPE_TRAITS(UINT32, uint32, VARINT, false, false) +TYPE_TRAITS(SINT64, int64, VARINT, false, false) +TYPE_TRAITS(SINT32, int32, VARINT, false, false) +TYPE_TRAITS(ENUM, int, VARINT, false, true) +TYPE_TRAITS(DOUBLE, double, FIXED64, false, false) +TYPE_TRAITS(FLOAT, float, FIXED32, false, false) +TYPE_TRAITS(FIXED64, uint64, FIXED64, false, false) +TYPE_TRAITS(FIXED32, uint32, FIXED32, false, false) +TYPE_TRAITS(SFIXED64, int64, FIXED64, false, false) +TYPE_TRAITS(SFIXED32, int32, FIXED32, false, false) +TYPE_TRAITS(BOOL, bool, VARINT, false, false) + +#undef TYPE_TRAITS + +template +class MapTypeHandler {}; + +template +class MapTypeHandler { + public: + // Enum type cannot be used for MapTypeHandler::Read. Define a type which will + // replace Enum with int. + typedef typename MapWireFieldTypeTraits::MapEntryAccessorType + MapEntryAccessorType; + // Internal stored type in MapEntryLite for given wire field type. + typedef typename MapWireFieldTypeTraits::TypeOnMemory TypeOnMemory; + // Corresponding wire type for field type. + static constexpr WireFormatLite::WireType kWireType = + MapWireFieldTypeTraits::kWireType; + // Whether wire type is for message. + static constexpr bool kIsMessage = + MapWireFieldTypeTraits::kIsMessage; + // Whether wire type is for enum. + static constexpr bool kIsEnum = + MapWireFieldTypeTraits::kIsEnum; + + // Functions used in parsing and serialization. =================== + static inline size_t ByteSize(const MapEntryAccessorType& value); + static inline int GetCachedSize(const MapEntryAccessorType& value); + static inline bool Read(io::CodedInputStream* input, + MapEntryAccessorType* value); + static inline const char* Read(const char* ptr, ParseContext* ctx, + MapEntryAccessorType* value); + + static inline uint8* Write(int field, const MapEntryAccessorType& value, + uint8* ptr, io::EpsCopyOutputStream* stream); + + // Functions to manipulate data on memory. ======================== + static inline const Type& GetExternalReference(const Type* value); + static inline void DeleteNoArena(const Type* x); + static inline void Merge(const Type& from, Type** to, Arena* arena); + static inline void Clear(Type** value, Arena* arena); + static inline void ClearMaybeByDefaultEnum(Type** value, Arena* arena, + int default_enum_value); + static inline void Initialize(Type** x, Arena* arena); + + static inline void InitializeMaybeByDefaultEnum(Type** x, + int default_enum_value, + Arena* arena); + static inline Type* EnsureMutable(Type** value, Arena* arena); + // SpaceUsedInMapEntry: Return bytes used by value in MapEntry, excluding + // those already calculate in sizeof(MapField). + static inline size_t SpaceUsedInMapEntryLong(const Type* value); + // Return bytes used by value in Map. + static inline size_t SpaceUsedInMapLong(const Type& value); + // Assign default value to given instance. + static inline void AssignDefaultValue(Type** value); + // Return default instance if value is not initialized when calling const + // reference accessor. + static inline const Type& DefaultIfNotInitialized(const Type* value, + const Type* default_value); + // Check if all required fields have values set. + static inline bool IsInitialized(Type* value); +}; + +#define MAP_HANDLER(FieldType) \ + template \ + class MapTypeHandler { \ + public: \ + typedef typename MapWireFieldTypeTraits::MapEntryAccessorType \ + MapEntryAccessorType; \ + typedef typename MapWireFieldTypeTraits::TypeOnMemory TypeOnMemory; \ + static const WireFormatLite::WireType kWireType = \ + MapWireFieldTypeTraits::kWireType; \ + static const bool kIsMessage = \ + MapWireFieldTypeTraits::kIsMessage; \ + static const bool kIsEnum = \ + MapWireFieldTypeTraits::kIsEnum; \ + static inline int ByteSize(const MapEntryAccessorType& value); \ + static inline int GetCachedSize(const MapEntryAccessorType& value); \ + static inline bool Read(io::CodedInputStream* input, \ + MapEntryAccessorType* value); \ + static inline const char* Read(const char* begin, ParseContext* ctx, \ + MapEntryAccessorType* value); \ + static inline uint8* Write(int field, const MapEntryAccessorType& value, \ + uint8* ptr, io::EpsCopyOutputStream* stream); \ + static inline const MapEntryAccessorType& GetExternalReference( \ + const TypeOnMemory& value); \ + static inline void DeleteNoArena(const TypeOnMemory& x); \ + static inline void Merge(const MapEntryAccessorType& from, \ + TypeOnMemory* to, Arena* arena); \ + static inline void Clear(TypeOnMemory* value, Arena* arena); \ + static inline void ClearMaybeByDefaultEnum(TypeOnMemory* value, \ + Arena* arena, \ + int default_enum); \ + static inline size_t SpaceUsedInMapEntryLong(const TypeOnMemory& value); \ + static inline size_t SpaceUsedInMapLong(const TypeOnMemory& value); \ + static inline size_t SpaceUsedInMapLong(ConstStringParam value); \ + static inline void AssignDefaultValue(TypeOnMemory* value); \ + static inline const MapEntryAccessorType& DefaultIfNotInitialized( \ + const TypeOnMemory& value, const TypeOnMemory& default_value); \ + static inline bool IsInitialized(const TypeOnMemory& value); \ + static void DeleteNoArena(TypeOnMemory& value); \ + static inline void Initialize(TypeOnMemory* value, Arena* arena); \ + static inline void InitializeMaybeByDefaultEnum(TypeOnMemory* value, \ + int default_enum_value, \ + Arena* arena); \ + static inline MapEntryAccessorType* EnsureMutable(TypeOnMemory* value, \ + Arena* arena); \ + }; +MAP_HANDLER(STRING) +MAP_HANDLER(BYTES) +MAP_HANDLER(INT64) +MAP_HANDLER(UINT64) +MAP_HANDLER(INT32) +MAP_HANDLER(UINT32) +MAP_HANDLER(SINT64) +MAP_HANDLER(SINT32) +MAP_HANDLER(ENUM) +MAP_HANDLER(DOUBLE) +MAP_HANDLER(FLOAT) +MAP_HANDLER(FIXED64) +MAP_HANDLER(FIXED32) +MAP_HANDLER(SFIXED64) +MAP_HANDLER(SFIXED32) +MAP_HANDLER(BOOL) +#undef MAP_HANDLER + +template +inline size_t MapTypeHandler::ByteSize( + const MapEntryAccessorType& value) { + return WireFormatLite::MessageSizeNoVirtual(value); +} + +#define GOOGLE_PROTOBUF_BYTE_SIZE(FieldType, DeclaredType) \ + template \ + inline int MapTypeHandler::ByteSize( \ + const MapEntryAccessorType& value) { \ + return static_cast(WireFormatLite::DeclaredType##Size(value)); \ + } + +GOOGLE_PROTOBUF_BYTE_SIZE(STRING, String) +GOOGLE_PROTOBUF_BYTE_SIZE(BYTES, Bytes) +GOOGLE_PROTOBUF_BYTE_SIZE(INT64, Int64) +GOOGLE_PROTOBUF_BYTE_SIZE(UINT64, UInt64) +GOOGLE_PROTOBUF_BYTE_SIZE(INT32, Int32) +GOOGLE_PROTOBUF_BYTE_SIZE(UINT32, UInt32) +GOOGLE_PROTOBUF_BYTE_SIZE(SINT64, SInt64) +GOOGLE_PROTOBUF_BYTE_SIZE(SINT32, SInt32) +GOOGLE_PROTOBUF_BYTE_SIZE(ENUM, Enum) + +#undef GOOGLE_PROTOBUF_BYTE_SIZE + +#define FIXED_BYTE_SIZE(FieldType, DeclaredType) \ + template \ + inline int MapTypeHandler::ByteSize( \ + const MapEntryAccessorType& /* value */) { \ + return WireFormatLite::k##DeclaredType##Size; \ + } + +FIXED_BYTE_SIZE(DOUBLE, Double) +FIXED_BYTE_SIZE(FLOAT, Float) +FIXED_BYTE_SIZE(FIXED64, Fixed64) +FIXED_BYTE_SIZE(FIXED32, Fixed32) +FIXED_BYTE_SIZE(SFIXED64, SFixed64) +FIXED_BYTE_SIZE(SFIXED32, SFixed32) +FIXED_BYTE_SIZE(BOOL, Bool) + +#undef FIXED_BYTE_SIZE + +template +inline int MapTypeHandler::GetCachedSize( + const MapEntryAccessorType& value) { + return static_cast(WireFormatLite::LengthDelimitedSize( + static_cast(value.GetCachedSize()))); +} + +#define GET_CACHED_SIZE(FieldType, DeclaredType) \ + template \ + inline int \ + MapTypeHandler::GetCachedSize( \ + const MapEntryAccessorType& value) { \ + return static_cast(WireFormatLite::DeclaredType##Size(value)); \ + } + +GET_CACHED_SIZE(STRING, String) +GET_CACHED_SIZE(BYTES, Bytes) +GET_CACHED_SIZE(INT64, Int64) +GET_CACHED_SIZE(UINT64, UInt64) +GET_CACHED_SIZE(INT32, Int32) +GET_CACHED_SIZE(UINT32, UInt32) +GET_CACHED_SIZE(SINT64, SInt64) +GET_CACHED_SIZE(SINT32, SInt32) +GET_CACHED_SIZE(ENUM, Enum) + +#undef GET_CACHED_SIZE + +#define GET_FIXED_CACHED_SIZE(FieldType, DeclaredType) \ + template \ + inline int \ + MapTypeHandler::GetCachedSize( \ + const MapEntryAccessorType& /* value */) { \ + return WireFormatLite::k##DeclaredType##Size; \ + } + +GET_FIXED_CACHED_SIZE(DOUBLE, Double) +GET_FIXED_CACHED_SIZE(FLOAT, Float) +GET_FIXED_CACHED_SIZE(FIXED64, Fixed64) +GET_FIXED_CACHED_SIZE(FIXED32, Fixed32) +GET_FIXED_CACHED_SIZE(SFIXED64, SFixed64) +GET_FIXED_CACHED_SIZE(SFIXED32, SFixed32) +GET_FIXED_CACHED_SIZE(BOOL, Bool) + +#undef GET_FIXED_CACHED_SIZE + +template +inline uint8* MapTypeHandler::Write( + int field, const MapEntryAccessorType& value, uint8* ptr, + io::EpsCopyOutputStream* stream) { + ptr = stream->EnsureSpace(ptr); + return WireFormatLite::InternalWriteMessage(field, value, ptr, stream); +} + +#define WRITE_METHOD(FieldType, DeclaredType) \ + template \ + inline uint8* MapTypeHandler::Write( \ + int field, const MapEntryAccessorType& value, uint8* ptr, \ + io::EpsCopyOutputStream* stream) { \ + ptr = stream->EnsureSpace(ptr); \ + return stream->Write##DeclaredType(field, value, ptr); \ + } + +WRITE_METHOD(STRING, String) +WRITE_METHOD(BYTES, Bytes) + +#undef WRITE_METHOD +#define WRITE_METHOD(FieldType, DeclaredType) \ + template \ + inline uint8* MapTypeHandler::Write( \ + int field, const MapEntryAccessorType& value, uint8* ptr, \ + io::EpsCopyOutputStream* stream) { \ + ptr = stream->EnsureSpace(ptr); \ + return WireFormatLite::Write##DeclaredType##ToArray(field, value, ptr); \ + } + +WRITE_METHOD(INT64, Int64) +WRITE_METHOD(UINT64, UInt64) +WRITE_METHOD(INT32, Int32) +WRITE_METHOD(UINT32, UInt32) +WRITE_METHOD(SINT64, SInt64) +WRITE_METHOD(SINT32, SInt32) +WRITE_METHOD(ENUM, Enum) +WRITE_METHOD(DOUBLE, Double) +WRITE_METHOD(FLOAT, Float) +WRITE_METHOD(FIXED64, Fixed64) +WRITE_METHOD(FIXED32, Fixed32) +WRITE_METHOD(SFIXED64, SFixed64) +WRITE_METHOD(SFIXED32, SFixed32) +WRITE_METHOD(BOOL, Bool) + +#undef WRITE_METHOD + +template +inline bool MapTypeHandler::Read( + io::CodedInputStream* input, MapEntryAccessorType* value) { + return WireFormatLite::ReadMessageNoVirtual(input, value); +} + +template +inline bool MapTypeHandler::Read( + io::CodedInputStream* input, MapEntryAccessorType* value) { + return WireFormatLite::ReadString(input, value); +} + +template +inline bool MapTypeHandler::Read( + io::CodedInputStream* input, MapEntryAccessorType* value) { + return WireFormatLite::ReadBytes(input, value); +} + +template +const char* MapTypeHandler::Read( + const char* ptr, ParseContext* ctx, MapEntryAccessorType* value) { + return ctx->ParseMessage(value, ptr); +} + +template +const char* MapTypeHandler::Read( + const char* ptr, ParseContext* ctx, MapEntryAccessorType* value) { + int size = ReadSize(&ptr); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + return ctx->ReadString(ptr, size, value); +} + +template +const char* MapTypeHandler::Read( + const char* ptr, ParseContext* ctx, MapEntryAccessorType* value) { + int size = ReadSize(&ptr); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + return ctx->ReadString(ptr, size, value); +} + +inline const char* ReadINT64(const char* ptr, int64* value) { + return VarintParse(ptr, reinterpret_cast(value)); +} +inline const char* ReadUINT64(const char* ptr, uint64* value) { + return VarintParse(ptr, value); +} +inline const char* ReadINT32(const char* ptr, int32* value) { + return VarintParse(ptr, reinterpret_cast(value)); +} +inline const char* ReadUINT32(const char* ptr, uint32* value) { + return VarintParse(ptr, value); +} +inline const char* ReadSINT64(const char* ptr, int64* value) { + *value = ReadVarintZigZag64(&ptr); + return ptr; +} +inline const char* ReadSINT32(const char* ptr, int32* value) { + *value = ReadVarintZigZag32(&ptr); + return ptr; +} +template +inline const char* ReadENUM(const char* ptr, E* value) { + *value = static_cast(ReadVarint32(&ptr)); + return ptr; +} +inline const char* ReadBOOL(const char* ptr, bool* value) { + *value = static_cast(ReadVarint32(&ptr)); + return ptr; +} + +template +inline const char* ReadUnaligned(const char* ptr, F* value) { + *value = UnalignedLoad(ptr); + return ptr + sizeof(F); +} +inline const char* ReadFLOAT(const char* ptr, float* value) { + return ReadUnaligned(ptr, value); +} +inline const char* ReadDOUBLE(const char* ptr, double* value) { + return ReadUnaligned(ptr, value); +} +inline const char* ReadFIXED64(const char* ptr, uint64* value) { + return ReadUnaligned(ptr, value); +} +inline const char* ReadFIXED32(const char* ptr, uint32* value) { + return ReadUnaligned(ptr, value); +} +inline const char* ReadSFIXED64(const char* ptr, int64* value) { + return ReadUnaligned(ptr, value); +} +inline const char* ReadSFIXED32(const char* ptr, int32* value) { + return ReadUnaligned(ptr, value); +} + +#define READ_METHOD(FieldType) \ + template \ + inline bool MapTypeHandler::Read( \ + io::CodedInputStream* input, MapEntryAccessorType* value) { \ + return WireFormatLite::ReadPrimitive( \ + input, value); \ + } \ + template \ + const char* MapTypeHandler::Read( \ + const char* begin, ParseContext* ctx, MapEntryAccessorType* value) { \ + (void)ctx; \ + return Read##FieldType(begin, value); \ + } + +READ_METHOD(INT64) +READ_METHOD(UINT64) +READ_METHOD(INT32) +READ_METHOD(UINT32) +READ_METHOD(SINT64) +READ_METHOD(SINT32) +READ_METHOD(ENUM) +READ_METHOD(DOUBLE) +READ_METHOD(FLOAT) +READ_METHOD(FIXED64) +READ_METHOD(FIXED32) +READ_METHOD(SFIXED64) +READ_METHOD(SFIXED32) +READ_METHOD(BOOL) + +#undef READ_METHOD + +// Definition for message handler + +template +inline const Type& +MapTypeHandler::GetExternalReference( + const Type* value) { + return *value; +} + +template +inline size_t MapTypeHandler::SpaceUsedInMapEntryLong(const Type* value) { + return value->SpaceUsedLong(); +} + +template +size_t MapTypeHandler::SpaceUsedInMapLong( + const Type& value) { + return value.SpaceUsedLong(); +} + +template +inline void MapTypeHandler::Clear( + Type** value, Arena* /* arena */) { + if (*value != NULL) (*value)->Clear(); +} +template +inline void +MapTypeHandler::ClearMaybeByDefaultEnum( + Type** value, Arena* /* arena */, int /* default_enum_value */) { + if (*value != NULL) (*value)->Clear(); +} +template +inline void MapTypeHandler::Merge( + const Type& from, Type** to, Arena* /* arena */) { + (*to)->MergeFrom(from); +} + +template +void MapTypeHandler::DeleteNoArena( + const Type* ptr) { + delete ptr; +} + +template +inline void MapTypeHandler::AssignDefaultValue(Type** value) { + *value = const_cast(Type::internal_default_instance()); +} + +template +inline void MapTypeHandler::Initialize( + Type** x, Arena* /* arena */) { + *x = NULL; +} + +template +inline void MapTypeHandler:: + InitializeMaybeByDefaultEnum(Type** x, int /* default_enum_value */, + Arena* /* arena */) { + *x = NULL; +} + +template +inline Type* MapTypeHandler::EnsureMutable( + Type** value, Arena* arena) { + if (*value == NULL) { + *value = MapArenaMessageCreator< + Type, + Arena::is_arena_constructable::type::value>::CreateMessage(arena); + } + return *value; +} + +template +inline const Type& +MapTypeHandler::DefaultIfNotInitialized( + const Type* value, const Type* default_value) { + return value != NULL ? *value : *default_value; +} + +template +inline bool MapTypeHandler::IsInitialized( + Type* value) { + return value ? value->IsInitialized() : false; +} + +// Definition for string/bytes handler + +#define STRING_OR_BYTES_HANDLER_FUNCTIONS(FieldType) \ + template \ + inline const typename MapTypeHandler::MapEntryAccessorType& \ + MapTypeHandler::GetExternalReference(const TypeOnMemory& value) { \ + return value.Get(); \ + } \ + template \ + inline size_t \ + MapTypeHandler::SpaceUsedInMapEntryLong(const TypeOnMemory& value) { \ + return sizeof(value); \ + } \ + template \ + inline size_t \ + MapTypeHandler::SpaceUsedInMapLong( \ + const TypeOnMemory& value) { \ + return sizeof(value); \ + } \ + template \ + inline size_t \ + MapTypeHandler::SpaceUsedInMapLong( \ + ConstStringParam value) { \ + return sizeof(std::string); \ + } \ + template \ + inline void MapTypeHandler::Clear( \ + TypeOnMemory* value, Arena* arena) { \ + value->ClearToEmpty(&internal::GetEmptyStringAlreadyInited(), arena); \ + } \ + template \ + inline void MapTypeHandler:: \ + ClearMaybeByDefaultEnum(TypeOnMemory* value, Arena* arena, \ + int /* default_enum */) { \ + Clear(value, arena); \ + } \ + template \ + inline void MapTypeHandler::Merge( \ + const MapEntryAccessorType& from, TypeOnMemory* to, Arena* arena) { \ + to->Set(&internal::GetEmptyStringAlreadyInited(), from, arena); \ + } \ + template \ + void MapTypeHandler::DeleteNoArena( \ + TypeOnMemory& value) { \ + value.DestroyNoArena(&internal::GetEmptyStringAlreadyInited()); \ + } \ + template \ + inline void \ + MapTypeHandler::AssignDefaultValue( \ + TypeOnMemory* /* value */) {} \ + template \ + inline void \ + MapTypeHandler::Initialize( \ + TypeOnMemory* value, Arena* /* arena */) { \ + value->UnsafeSetDefault(&internal::GetEmptyStringAlreadyInited()); \ + } \ + template \ + inline void MapTypeHandler:: \ + InitializeMaybeByDefaultEnum( \ + TypeOnMemory* value, int /* default_enum_value */, Arena* arena) { \ + Initialize(value, arena); \ + } \ + template \ + inline typename MapTypeHandler::MapEntryAccessorType* \ + MapTypeHandler::EnsureMutable( \ + TypeOnMemory* value, Arena* arena) { \ + return value->Mutable(&internal::GetEmptyStringAlreadyInited(), arena); \ + } \ + template \ + inline const typename MapTypeHandler::MapEntryAccessorType& \ + MapTypeHandler:: \ + DefaultIfNotInitialized(const TypeOnMemory& value, \ + const TypeOnMemory& /* default_value */) { \ + return value.Get(); \ + } \ + template \ + inline bool \ + MapTypeHandler::IsInitialized( \ + const TypeOnMemory& /* value */) { \ + return true; \ + } +STRING_OR_BYTES_HANDLER_FUNCTIONS(STRING) +STRING_OR_BYTES_HANDLER_FUNCTIONS(BYTES) +#undef STRING_OR_BYTES_HANDLER_FUNCTIONS + +#define PRIMITIVE_HANDLER_FUNCTIONS(FieldType) \ + template \ + inline const typename MapTypeHandler::MapEntryAccessorType& \ + MapTypeHandler::GetExternalReference(const TypeOnMemory& value) { \ + return value; \ + } \ + template \ + inline size_t MapTypeHandler:: \ + SpaceUsedInMapEntryLong(const TypeOnMemory& /* value */) { \ + return 0; \ + } \ + template \ + inline size_t \ + MapTypeHandler::SpaceUsedInMapLong( \ + const TypeOnMemory& /* value */) { \ + return sizeof(Type); \ + } \ + template \ + inline void MapTypeHandler::Clear( \ + TypeOnMemory* value, Arena* /* arena */) { \ + *value = 0; \ + } \ + template \ + inline void MapTypeHandler:: \ + ClearMaybeByDefaultEnum(TypeOnMemory* value, Arena* /* arena */, \ + int default_enum_value) { \ + *value = static_cast(default_enum_value); \ + } \ + template \ + inline void MapTypeHandler::Merge( \ + const MapEntryAccessorType& from, TypeOnMemory* to, \ + Arena* /* arena */) { \ + *to = from; \ + } \ + template \ + inline void MapTypeHandler::DeleteNoArena(TypeOnMemory& /* x */) {} \ + template \ + inline void \ + MapTypeHandler::AssignDefaultValue( \ + TypeOnMemory* /* value */) {} \ + template \ + inline void \ + MapTypeHandler::Initialize( \ + TypeOnMemory* value, Arena* /* arena */) { \ + *value = 0; \ + } \ + template \ + inline void MapTypeHandler:: \ + InitializeMaybeByDefaultEnum( \ + TypeOnMemory* value, int default_enum_value, Arena* /* arena */) { \ + *value = static_cast(default_enum_value); \ + } \ + template \ + inline typename MapTypeHandler::MapEntryAccessorType* \ + MapTypeHandler::EnsureMutable( \ + TypeOnMemory* value, Arena* /* arena */) { \ + return value; \ + } \ + template \ + inline const typename MapTypeHandler::MapEntryAccessorType& \ + MapTypeHandler:: \ + DefaultIfNotInitialized(const TypeOnMemory& value, \ + const TypeOnMemory& /* default_value */) { \ + return value; \ + } \ + template \ + inline bool \ + MapTypeHandler::IsInitialized( \ + const TypeOnMemory& /* value */) { \ + return true; \ + } +PRIMITIVE_HANDLER_FUNCTIONS(INT64) +PRIMITIVE_HANDLER_FUNCTIONS(UINT64) +PRIMITIVE_HANDLER_FUNCTIONS(INT32) +PRIMITIVE_HANDLER_FUNCTIONS(UINT32) +PRIMITIVE_HANDLER_FUNCTIONS(SINT64) +PRIMITIVE_HANDLER_FUNCTIONS(SINT32) +PRIMITIVE_HANDLER_FUNCTIONS(ENUM) +PRIMITIVE_HANDLER_FUNCTIONS(DOUBLE) +PRIMITIVE_HANDLER_FUNCTIONS(FLOAT) +PRIMITIVE_HANDLER_FUNCTIONS(FIXED64) +PRIMITIVE_HANDLER_FUNCTIONS(FIXED32) +PRIMITIVE_HANDLER_FUNCTIONS(SFIXED64) +PRIMITIVE_HANDLER_FUNCTIONS(SFIXED32) +PRIMITIVE_HANDLER_FUNCTIONS(BOOL) +#undef PRIMITIVE_HANDLER_FUNCTIONS + +} // namespace internal +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_TYPE_HANDLER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/message.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/message.h new file mode 100644 index 0000000000000000000000000000000000000000..89761c62ed239aaff1103627725864c56afea253 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/message.h @@ -0,0 +1,1344 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Defines Message, the abstract interface implemented by non-lite +// protocol message objects. Although it's possible to implement this +// interface manually, most users will use the protocol compiler to +// generate implementations. +// +// Example usage: +// +// Say you have a message defined as: +// +// message Foo { +// optional string text = 1; +// repeated int32 numbers = 2; +// } +// +// Then, if you used the protocol compiler to generate a class from the above +// definition, you could use it like so: +// +// std::string data; // Will store a serialized version of the message. +// +// { +// // Create a message and serialize it. +// Foo foo; +// foo.set_text("Hello World!"); +// foo.add_numbers(1); +// foo.add_numbers(5); +// foo.add_numbers(42); +// +// foo.SerializeToString(&data); +// } +// +// { +// // Parse the serialized message and check that it contains the +// // correct data. +// Foo foo; +// foo.ParseFromString(data); +// +// assert(foo.text() == "Hello World!"); +// assert(foo.numbers_size() == 3); +// assert(foo.numbers(0) == 1); +// assert(foo.numbers(1) == 5); +// assert(foo.numbers(2) == 42); +// } +// +// { +// // Same as the last block, but do it dynamically via the Message +// // reflection interface. +// Message* foo = new Foo; +// const Descriptor* descriptor = foo->GetDescriptor(); +// +// // Get the descriptors for the fields we're interested in and verify +// // their types. +// const FieldDescriptor* text_field = descriptor->FindFieldByName("text"); +// assert(text_field != nullptr); +// assert(text_field->type() == FieldDescriptor::TYPE_STRING); +// assert(text_field->label() == FieldDescriptor::LABEL_OPTIONAL); +// const FieldDescriptor* numbers_field = descriptor-> +// FindFieldByName("numbers"); +// assert(numbers_field != nullptr); +// assert(numbers_field->type() == FieldDescriptor::TYPE_INT32); +// assert(numbers_field->label() == FieldDescriptor::LABEL_REPEATED); +// +// // Parse the message. +// foo->ParseFromString(data); +// +// // Use the reflection interface to examine the contents. +// const Reflection* reflection = foo->GetReflection(); +// assert(reflection->GetString(*foo, text_field) == "Hello World!"); +// assert(reflection->FieldSize(*foo, numbers_field) == 3); +// assert(reflection->GetRepeatedInt32(*foo, numbers_field, 0) == 1); +// assert(reflection->GetRepeatedInt32(*foo, numbers_field, 1) == 5); +// assert(reflection->GetRepeatedInt32(*foo, numbers_field, 2) == 42); +// +// delete foo; +// } + +#ifndef GOOGLE_PROTOBUF_MESSAGE_H__ +#define GOOGLE_PROTOBUF_MESSAGE_H__ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +#define GOOGLE_PROTOBUF_HAS_ONEOF +#define GOOGLE_PROTOBUF_HAS_ARENAS + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +// Defined in this file. +class Message; +class Reflection; +class MessageFactory; + +// Defined in other files. +class AssignDescriptorsHelper; +class DynamicMessageFactory; +class MapKey; +class MapValueRef; +class MapIterator; +class MapReflectionTester; + +namespace internal { +struct DescriptorTable; +class MapFieldBase; +} +class UnknownFieldSet; // unknown_field_set.h +namespace io { +class ZeroCopyInputStream; // zero_copy_stream.h +class ZeroCopyOutputStream; // zero_copy_stream.h +class CodedInputStream; // coded_stream.h +class CodedOutputStream; // coded_stream.h +} // namespace io +namespace python { +class MapReflectionFriend; // scalar_map_container.h +} +namespace expr { +class CelMapReflectionFriend; // field_backed_map_impl.cc +} + +namespace internal { +class MapFieldPrinterHelper; // text_format.cc +} + + +namespace internal { +class ReflectionAccessor; // message.cc +class ReflectionOps; // reflection_ops.h +class MapKeySorter; // wire_format.cc +class WireFormat; // wire_format.h +class MapFieldReflectionTest; // map_test.cc +} // namespace internal + +template +class RepeatedField; // repeated_field.h + +template +class RepeatedPtrField; // repeated_field.h + +// A container to hold message metadata. +struct Metadata { + const Descriptor* descriptor; + const Reflection* reflection; +}; + +namespace internal { +template +inline To* GetPointerAtOffset(Message* message, uint32 offset) { + return reinterpret_cast(reinterpret_cast(message) + offset); +} + +template +const To* GetConstPointerAtOffset(const Message* message, uint32 offset) { + return reinterpret_cast(reinterpret_cast(message) + + offset); +} + +template +const To& GetConstRefAtOffset(const Message& message, uint32 offset) { + return *GetConstPointerAtOffset(&message, offset); +} + +bool CreateUnknownEnumValues(const FieldDescriptor* field); +} // namespace internal + +// Abstract interface for protocol messages. +// +// See also MessageLite, which contains most every-day operations. Message +// adds descriptors and reflection on top of that. +// +// The methods of this class that are virtual but not pure-virtual have +// default implementations based on reflection. Message classes which are +// optimized for speed will want to override these with faster implementations, +// but classes optimized for code size may be happy with keeping them. See +// the optimize_for option in descriptor.proto. +// +// Users must not derive from this class. Only the protocol compiler and +// the internal library are allowed to create subclasses. +class PROTOBUF_EXPORT Message : public MessageLite { + public: + inline Message() {} + + // Basic Operations ------------------------------------------------ + + // Construct a new instance of the same type. Ownership is passed to the + // caller. (This is also defined in MessageLite, but is defined again here + // for return-type covariance.) + Message* New() const override = 0; + + // Construct a new instance on the arena. Ownership is passed to the caller + // if arena is a nullptr. Default implementation allows for API compatibility + // during the Arena transition. + Message* New(Arena* arena) const override { + Message* message = New(); + if (arena != nullptr) { + arena->Own(message); + } + return message; + } + + // Make this message into a copy of the given message. The given message + // must have the same descriptor, but need not necessarily be the same class. + // By default this is just implemented as "Clear(); MergeFrom(from);". + virtual void CopyFrom(const Message& from); + + // Merge the fields from the given message into this message. Singular + // fields will be overwritten, if specified in from, except for embedded + // messages which will be merged. Repeated fields will be concatenated. + // The given message must be of the same type as this message (i.e. the + // exact same class). + virtual void MergeFrom(const Message& from); + + // Verifies that IsInitialized() returns true. GOOGLE_CHECK-fails otherwise, with + // a nice error message. + void CheckInitialized() const; + + // Slowly build a list of all required fields that are not set. + // This is much, much slower than IsInitialized() as it is implemented + // purely via reflection. Generally, you should not call this unless you + // have already determined that an error exists by calling IsInitialized(). + void FindInitializationErrors(std::vector* errors) const; + + // Like FindInitializationErrors, but joins all the strings, delimited by + // commas, and returns them. + std::string InitializationErrorString() const override; + + // Clears all unknown fields from this message and all embedded messages. + // Normally, if unknown tag numbers are encountered when parsing a message, + // the tag and value are stored in the message's UnknownFieldSet and + // then written back out when the message is serialized. This allows servers + // which simply route messages to other servers to pass through messages + // that have new field definitions which they don't yet know about. However, + // this behavior can have security implications. To avoid it, call this + // method after parsing. + // + // See Reflection::GetUnknownFields() for more on unknown fields. + virtual void DiscardUnknownFields(); + + // Computes (an estimate of) the total number of bytes currently used for + // storing the message in memory. The default implementation calls the + // Reflection object's SpaceUsed() method. + // + // SpaceUsed() is noticeably slower than ByteSize(), as it is implemented + // using reflection (rather than the generated code implementation for + // ByteSize()). Like ByteSize(), its CPU time is linear in the number of + // fields defined for the proto. + virtual size_t SpaceUsedLong() const; + + PROTOBUF_DEPRECATED_MSG("Please use SpaceUsedLong() instead") + int SpaceUsed() const { return internal::ToIntSize(SpaceUsedLong()); } + + // Debugging & Testing---------------------------------------------- + + // Generates a human readable form of this message, useful for debugging + // and other purposes. + std::string DebugString() const; + // Like DebugString(), but with less whitespace. + std::string ShortDebugString() const; + // Like DebugString(), but do not escape UTF-8 byte sequences. + std::string Utf8DebugString() const; + // Convenience function useful in GDB. Prints DebugString() to stdout. + void PrintDebugString() const; + + // Reflection-based methods ---------------------------------------- + // These methods are pure-virtual in MessageLite, but Message provides + // reflection-based default implementations. + + std::string GetTypeName() const override; + void Clear() override; + + // Returns whether all required fields have been set. Note that required + // fields no longer exist starting in proto3. + bool IsInitialized() const override; + + void CheckTypeAndMergeFrom(const MessageLite& other) override; + // Reflective parser + const char* _InternalParse(const char* ptr, + internal::ParseContext* ctx) override; + size_t ByteSizeLong() const override; + uint8* _InternalSerialize(uint8* target, + io::EpsCopyOutputStream* stream) const override; + + private: + // This is called only by the default implementation of ByteSize(), to + // update the cached size. If you override ByteSize(), you do not need + // to override this. If you do not override ByteSize(), you MUST override + // this; the default implementation will crash. + // + // The method is private because subclasses should never call it; only + // override it. Yes, C++ lets you do that. Crazy, huh? + virtual void SetCachedSize(int size) const; + + public: + // Introspection --------------------------------------------------- + + + // Get a non-owning pointer to a Descriptor for this message's type. This + // describes what fields the message contains, the types of those fields, etc. + // This object remains property of the Message. + const Descriptor* GetDescriptor() const { return GetMetadata().descriptor; } + + // Get a non-owning pointer to the Reflection interface for this Message, + // which can be used to read and modify the fields of the Message dynamically + // (in other words, without knowing the message type at compile time). This + // object remains property of the Message. + const Reflection* GetReflection() const { return GetMetadata().reflection; } + + protected: + // Get a struct containing the metadata for the Message, which is used in turn + // to implement GetDescriptor() and GetReflection() above. + virtual Metadata GetMetadata() const = 0; + + inline explicit Message(Arena* arena) : MessageLite(arena) {} + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Message); +}; + +namespace internal { +// Forward-declare interfaces used to implement RepeatedFieldRef. +// These are protobuf internals that users shouldn't care about. +class RepeatedFieldAccessor; +} // namespace internal + +// Forward-declare RepeatedFieldRef templates. The second type parameter is +// used for SFINAE tricks. Users should ignore it. +template +class RepeatedFieldRef; + +template +class MutableRepeatedFieldRef; + +// This interface contains methods that can be used to dynamically access +// and modify the fields of a protocol message. Their semantics are +// similar to the accessors the protocol compiler generates. +// +// To get the Reflection for a given Message, call Message::GetReflection(). +// +// This interface is separate from Message only for efficiency reasons; +// the vast majority of implementations of Message will share the same +// implementation of Reflection (GeneratedMessageReflection, +// defined in generated_message.h), and all Messages of a particular class +// should share the same Reflection object (though you should not rely on +// the latter fact). +// +// There are several ways that these methods can be used incorrectly. For +// example, any of the following conditions will lead to undefined +// results (probably assertion failures): +// - The FieldDescriptor is not a field of this message type. +// - The method called is not appropriate for the field's type. For +// each field type in FieldDescriptor::TYPE_*, there is only one +// Get*() method, one Set*() method, and one Add*() method that is +// valid for that type. It should be obvious which (except maybe +// for TYPE_BYTES, which are represented using strings in C++). +// - A Get*() or Set*() method for singular fields is called on a repeated +// field. +// - GetRepeated*(), SetRepeated*(), or Add*() is called on a non-repeated +// field. +// - The Message object passed to any method is not of the right type for +// this Reflection object (i.e. message.GetReflection() != reflection). +// +// You might wonder why there is not any abstract representation for a field +// of arbitrary type. E.g., why isn't there just a "GetField()" method that +// returns "const Field&", where "Field" is some class with accessors like +// "GetInt32Value()". The problem is that someone would have to deal with +// allocating these Field objects. For generated message classes, having to +// allocate space for an additional object to wrap every field would at least +// double the message's memory footprint, probably worse. Allocating the +// objects on-demand, on the other hand, would be expensive and prone to +// memory leaks. So, instead we ended up with this flat interface. +class PROTOBUF_EXPORT Reflection final { + public: + // Get the UnknownFieldSet for the message. This contains fields which + // were seen when the Message was parsed but were not recognized according + // to the Message's definition. + const UnknownFieldSet& GetUnknownFields(const Message& message) const; + // Get a mutable pointer to the UnknownFieldSet for the message. This + // contains fields which were seen when the Message was parsed but were not + // recognized according to the Message's definition. + UnknownFieldSet* MutableUnknownFields(Message* message) const; + + // Estimate the amount of memory used by the message object. + size_t SpaceUsedLong(const Message& message) const; + + PROTOBUF_DEPRECATED_MSG("Please use SpaceUsedLong() instead") + int SpaceUsed(const Message& message) const { + return internal::ToIntSize(SpaceUsedLong(message)); + } + + // Check if the given non-repeated field is set. + bool HasField(const Message& message, const FieldDescriptor* field) const; + + // Get the number of elements of a repeated field. + int FieldSize(const Message& message, const FieldDescriptor* field) const; + + // Clear the value of a field, so that HasField() returns false or + // FieldSize() returns zero. + void ClearField(Message* message, const FieldDescriptor* field) const; + + // Check if the oneof is set. Returns true if any field in oneof + // is set, false otherwise. + bool HasOneof(const Message& message, + const OneofDescriptor* oneof_descriptor) const; + + void ClearOneof(Message* message, + const OneofDescriptor* oneof_descriptor) const; + + // Returns the field descriptor if the oneof is set. nullptr otherwise. + const FieldDescriptor* GetOneofFieldDescriptor( + const Message& message, const OneofDescriptor* oneof_descriptor) const; + + // Removes the last element of a repeated field. + // We don't provide a way to remove any element other than the last + // because it invites inefficient use, such as O(n^2) filtering loops + // that should have been O(n). If you want to remove an element other + // than the last, the best way to do it is to re-arrange the elements + // (using Swap()) so that the one you want removed is at the end, then + // call RemoveLast(). + void RemoveLast(Message* message, const FieldDescriptor* field) const; + // Removes the last element of a repeated message field, and returns the + // pointer to the caller. Caller takes ownership of the returned pointer. + Message* ReleaseLast(Message* message, const FieldDescriptor* field) const; + + // Swap the complete contents of two messages. + void Swap(Message* message1, Message* message2) const; + + // Swap fields listed in fields vector of two messages. + void SwapFields(Message* message1, Message* message2, + const std::vector& fields) const; + + // Swap two elements of a repeated field. + void SwapElements(Message* message, const FieldDescriptor* field, int index1, + int index2) const; + + // List all fields of the message which are currently set, except for unknown + // fields, but including extension known to the parser (i.e. compiled in). + // Singular fields will only be listed if HasField(field) would return true + // and repeated fields will only be listed if FieldSize(field) would return + // non-zero. Fields (both normal fields and extension fields) will be listed + // ordered by field number. + // Use Reflection::GetUnknownFields() or message.unknown_fields() to also get + // access to fields/extensions unknown to the parser. + void ListFields(const Message& message, + std::vector* output) const; + + // Singular field getters ------------------------------------------ + // These get the value of a non-repeated field. They return the default + // value for fields that aren't set. + + int32 GetInt32(const Message& message, const FieldDescriptor* field) const; + int64 GetInt64(const Message& message, const FieldDescriptor* field) const; + uint32 GetUInt32(const Message& message, const FieldDescriptor* field) const; + uint64 GetUInt64(const Message& message, const FieldDescriptor* field) const; + float GetFloat(const Message& message, const FieldDescriptor* field) const; + double GetDouble(const Message& message, const FieldDescriptor* field) const; + bool GetBool(const Message& message, const FieldDescriptor* field) const; + std::string GetString(const Message& message, + const FieldDescriptor* field) const; + const EnumValueDescriptor* GetEnum(const Message& message, + const FieldDescriptor* field) const; + + // GetEnumValue() returns an enum field's value as an integer rather than + // an EnumValueDescriptor*. If the integer value does not correspond to a + // known value descriptor, a new value descriptor is created. (Such a value + // will only be present when the new unknown-enum-value semantics are enabled + // for a message.) + int GetEnumValue(const Message& message, const FieldDescriptor* field) const; + + // See MutableMessage() for the meaning of the "factory" parameter. + const Message& GetMessage(const Message& message, + const FieldDescriptor* field, + MessageFactory* factory = nullptr) const; + + // Get a string value without copying, if possible. + // + // GetString() necessarily returns a copy of the string. This can be + // inefficient when the std::string is already stored in a std::string object + // in the underlying message. GetStringReference() will return a reference to + // the underlying std::string in this case. Otherwise, it will copy the + // string into *scratch and return that. + // + // Note: It is perfectly reasonable and useful to write code like: + // str = reflection->GetStringReference(message, field, &str); + // This line would ensure that only one copy of the string is made + // regardless of the field's underlying representation. When initializing + // a newly-constructed string, though, it's just as fast and more + // readable to use code like: + // std::string str = reflection->GetString(message, field); + const std::string& GetStringReference(const Message& message, + const FieldDescriptor* field, + std::string* scratch) const; + + + // Singular field mutators ----------------------------------------- + // These mutate the value of a non-repeated field. + + void SetInt32(Message* message, const FieldDescriptor* field, + int32 value) const; + void SetInt64(Message* message, const FieldDescriptor* field, + int64 value) const; + void SetUInt32(Message* message, const FieldDescriptor* field, + uint32 value) const; + void SetUInt64(Message* message, const FieldDescriptor* field, + uint64 value) const; + void SetFloat(Message* message, const FieldDescriptor* field, + float value) const; + void SetDouble(Message* message, const FieldDescriptor* field, + double value) const; + void SetBool(Message* message, const FieldDescriptor* field, + bool value) const; + void SetString(Message* message, const FieldDescriptor* field, + std::string value) const; + void SetEnum(Message* message, const FieldDescriptor* field, + const EnumValueDescriptor* value) const; + // Set an enum field's value with an integer rather than EnumValueDescriptor. + // For proto3 this is just setting the enum field to the value specified, for + // proto2 it's more complicated. If value is a known enum value the field is + // set as usual. If the value is unknown then it is added to the unknown field + // set. Note this matches the behavior of parsing unknown enum values. + // If multiple calls with unknown values happen than they are all added to the + // unknown field set in order of the calls. + void SetEnumValue(Message* message, const FieldDescriptor* field, + int value) const; + + // Get a mutable pointer to a field with a message type. If a MessageFactory + // is provided, it will be used to construct instances of the sub-message; + // otherwise, the default factory is used. If the field is an extension that + // does not live in the same pool as the containing message's descriptor (e.g. + // it lives in an overlay pool), then a MessageFactory must be provided. + // If you have no idea what that meant, then you probably don't need to worry + // about it (don't provide a MessageFactory). WARNING: If the + // FieldDescriptor is for a compiled-in extension, then + // factory->GetPrototype(field->message_type()) MUST return an instance of + // the compiled-in class for this type, NOT DynamicMessage. + Message* MutableMessage(Message* message, const FieldDescriptor* field, + MessageFactory* factory = nullptr) const; + // Replaces the message specified by 'field' with the already-allocated object + // sub_message, passing ownership to the message. If the field contained a + // message, that message is deleted. If sub_message is nullptr, the field is + // cleared. + void SetAllocatedMessage(Message* message, Message* sub_message, + const FieldDescriptor* field) const; + // Releases the message specified by 'field' and returns the pointer, + // ReleaseMessage() will return the message the message object if it exists. + // Otherwise, it may or may not return nullptr. In any case, if the return + // value is non-null, the caller takes ownership of the pointer. + // If the field existed (HasField() is true), then the returned pointer will + // be the same as the pointer returned by MutableMessage(). + // This function has the same effect as ClearField(). + Message* ReleaseMessage(Message* message, const FieldDescriptor* field, + MessageFactory* factory = nullptr) const; + + + // Repeated field getters ------------------------------------------ + // These get the value of one element of a repeated field. + + int32 GetRepeatedInt32(const Message& message, const FieldDescriptor* field, + int index) const; + int64 GetRepeatedInt64(const Message& message, const FieldDescriptor* field, + int index) const; + uint32 GetRepeatedUInt32(const Message& message, const FieldDescriptor* field, + int index) const; + uint64 GetRepeatedUInt64(const Message& message, const FieldDescriptor* field, + int index) const; + float GetRepeatedFloat(const Message& message, const FieldDescriptor* field, + int index) const; + double GetRepeatedDouble(const Message& message, const FieldDescriptor* field, + int index) const; + bool GetRepeatedBool(const Message& message, const FieldDescriptor* field, + int index) const; + std::string GetRepeatedString(const Message& message, + const FieldDescriptor* field, int index) const; + const EnumValueDescriptor* GetRepeatedEnum(const Message& message, + const FieldDescriptor* field, + int index) const; + // GetRepeatedEnumValue() returns an enum field's value as an integer rather + // than an EnumValueDescriptor*. If the integer value does not correspond to a + // known value descriptor, a new value descriptor is created. (Such a value + // will only be present when the new unknown-enum-value semantics are enabled + // for a message.) + int GetRepeatedEnumValue(const Message& message, const FieldDescriptor* field, + int index) const; + const Message& GetRepeatedMessage(const Message& message, + const FieldDescriptor* field, + int index) const; + + // See GetStringReference(), above. + const std::string& GetRepeatedStringReference(const Message& message, + const FieldDescriptor* field, + int index, + std::string* scratch) const; + + + // Repeated field mutators ----------------------------------------- + // These mutate the value of one element of a repeated field. + + void SetRepeatedInt32(Message* message, const FieldDescriptor* field, + int index, int32 value) const; + void SetRepeatedInt64(Message* message, const FieldDescriptor* field, + int index, int64 value) const; + void SetRepeatedUInt32(Message* message, const FieldDescriptor* field, + int index, uint32 value) const; + void SetRepeatedUInt64(Message* message, const FieldDescriptor* field, + int index, uint64 value) const; + void SetRepeatedFloat(Message* message, const FieldDescriptor* field, + int index, float value) const; + void SetRepeatedDouble(Message* message, const FieldDescriptor* field, + int index, double value) const; + void SetRepeatedBool(Message* message, const FieldDescriptor* field, + int index, bool value) const; + void SetRepeatedString(Message* message, const FieldDescriptor* field, + int index, std::string value) const; + void SetRepeatedEnum(Message* message, const FieldDescriptor* field, + int index, const EnumValueDescriptor* value) const; + // Set an enum field's value with an integer rather than EnumValueDescriptor. + // For proto3 this is just setting the enum field to the value specified, for + // proto2 it's more complicated. If value is a known enum value the field is + // set as usual. If the value is unknown then it is added to the unknown field + // set. Note this matches the behavior of parsing unknown enum values. + // If multiple calls with unknown values happen than they are all added to the + // unknown field set in order of the calls. + void SetRepeatedEnumValue(Message* message, const FieldDescriptor* field, + int index, int value) const; + // Get a mutable pointer to an element of a repeated field with a message + // type. + Message* MutableRepeatedMessage(Message* message, + const FieldDescriptor* field, + int index) const; + + + // Repeated field adders ------------------------------------------- + // These add an element to a repeated field. + + void AddInt32(Message* message, const FieldDescriptor* field, + int32 value) const; + void AddInt64(Message* message, const FieldDescriptor* field, + int64 value) const; + void AddUInt32(Message* message, const FieldDescriptor* field, + uint32 value) const; + void AddUInt64(Message* message, const FieldDescriptor* field, + uint64 value) const; + void AddFloat(Message* message, const FieldDescriptor* field, + float value) const; + void AddDouble(Message* message, const FieldDescriptor* field, + double value) const; + void AddBool(Message* message, const FieldDescriptor* field, + bool value) const; + void AddString(Message* message, const FieldDescriptor* field, + std::string value) const; + void AddEnum(Message* message, const FieldDescriptor* field, + const EnumValueDescriptor* value) const; + // Add an integer value to a repeated enum field rather than + // EnumValueDescriptor. For proto3 this is just setting the enum field to the + // value specified, for proto2 it's more complicated. If value is a known enum + // value the field is set as usual. If the value is unknown then it is added + // to the unknown field set. Note this matches the behavior of parsing unknown + // enum values. If multiple calls with unknown values happen than they are all + // added to the unknown field set in order of the calls. + void AddEnumValue(Message* message, const FieldDescriptor* field, + int value) const; + // See MutableMessage() for comments on the "factory" parameter. + Message* AddMessage(Message* message, const FieldDescriptor* field, + MessageFactory* factory = nullptr) const; + + // Appends an already-allocated object 'new_entry' to the repeated field + // specified by 'field' passing ownership to the message. + void AddAllocatedMessage(Message* message, const FieldDescriptor* field, + Message* new_entry) const; + + + // Get a RepeatedFieldRef object that can be used to read the underlying + // repeated field. The type parameter T must be set according to the + // field's cpp type. The following table shows the mapping from cpp type + // to acceptable T. + // + // field->cpp_type() T + // CPPTYPE_INT32 int32 + // CPPTYPE_UINT32 uint32 + // CPPTYPE_INT64 int64 + // CPPTYPE_UINT64 uint64 + // CPPTYPE_DOUBLE double + // CPPTYPE_FLOAT float + // CPPTYPE_BOOL bool + // CPPTYPE_ENUM generated enum type or int32 + // CPPTYPE_STRING std::string + // CPPTYPE_MESSAGE generated message type or google::protobuf::Message + // + // A RepeatedFieldRef object can be copied and the resulted object will point + // to the same repeated field in the same message. The object can be used as + // long as the message is not destroyed. + // + // Note that to use this method users need to include the header file + // "reflection.h" (which defines the RepeatedFieldRef class templates). + template + RepeatedFieldRef GetRepeatedFieldRef(const Message& message, + const FieldDescriptor* field) const; + + // Like GetRepeatedFieldRef() but return an object that can also be used + // manipulate the underlying repeated field. + template + MutableRepeatedFieldRef GetMutableRepeatedFieldRef( + Message* message, const FieldDescriptor* field) const; + + // DEPRECATED. Please use Get(Mutable)RepeatedFieldRef() for repeated field + // access. The following repeated field accesors will be removed in the + // future. + // + // Repeated field accessors ------------------------------------------------- + // The methods above, e.g. GetRepeatedInt32(msg, fd, index), provide singular + // access to the data in a RepeatedField. The methods below provide aggregate + // access by exposing the RepeatedField object itself with the Message. + // Applying these templates to inappropriate types will lead to an undefined + // reference at link time (e.g. GetRepeatedField<***double>), or possibly a + // template matching error at compile time (e.g. GetRepeatedPtrField). + // + // Usage example: my_doubs = refl->GetRepeatedField(msg, fd); + + // DEPRECATED. Please use GetRepeatedFieldRef(). + // + // for T = Cord and all protobuf scalar types except enums. + template + PROTOBUF_DEPRECATED_MSG("Please use GetRepeatedFieldRef() instead") + const RepeatedField& GetRepeatedField(const Message& msg, + const FieldDescriptor* d) const { + return GetRepeatedFieldInternal(msg, d); + } + + // DEPRECATED. Please use GetMutableRepeatedFieldRef(). + // + // for T = Cord and all protobuf scalar types except enums. + template + PROTOBUF_DEPRECATED_MSG("Please use GetMutableRepeatedFieldRef() instead") + RepeatedField* MutableRepeatedField(Message* msg, + const FieldDescriptor* d) const { + return MutableRepeatedFieldInternal(msg, d); + } + + // DEPRECATED. Please use GetRepeatedFieldRef(). + // + // for T = std::string, google::protobuf::internal::StringPieceField + // google::protobuf::Message & descendants. + template + PROTOBUF_DEPRECATED_MSG("Please use GetRepeatedFieldRef() instead") + const RepeatedPtrField& GetRepeatedPtrField( + const Message& msg, const FieldDescriptor* d) const { + return GetRepeatedPtrFieldInternal(msg, d); + } + + // DEPRECATED. Please use GetMutableRepeatedFieldRef(). + // + // for T = std::string, google::protobuf::internal::StringPieceField + // google::protobuf::Message & descendants. + template + PROTOBUF_DEPRECATED_MSG("Please use GetMutableRepeatedFieldRef() instead") + RepeatedPtrField* MutableRepeatedPtrField(Message* msg, + const FieldDescriptor* d) const { + return MutableRepeatedPtrFieldInternal(msg, d); + } + + // Extensions ---------------------------------------------------------------- + + // Try to find an extension of this message type by fully-qualified field + // name. Returns nullptr if no extension is known for this name or number. + const FieldDescriptor* FindKnownExtensionByName( + const std::string& name) const; + + // Try to find an extension of this message type by field number. + // Returns nullptr if no extension is known for this name or number. + const FieldDescriptor* FindKnownExtensionByNumber(int number) const; + + // Feature Flags ------------------------------------------------------------- + + // Does this message support storing arbitrary integer values in enum fields? + // If |true|, GetEnumValue/SetEnumValue and associated repeated-field versions + // take arbitrary integer values, and the legacy GetEnum() getter will + // dynamically create an EnumValueDescriptor for any integer value without + // one. If |false|, setting an unknown enum value via the integer-based + // setters results in undefined behavior (in practice, GOOGLE_DCHECK-fails). + // + // Generic code that uses reflection to handle messages with enum fields + // should check this flag before using the integer-based setter, and either + // downgrade to a compatible value or use the UnknownFieldSet if not. For + // example: + // + // int new_value = GetValueFromApplicationLogic(); + // if (reflection->SupportsUnknownEnumValues()) { + // reflection->SetEnumValue(message, field, new_value); + // } else { + // if (field_descriptor->enum_type()-> + // FindValueByNumber(new_value) != nullptr) { + // reflection->SetEnumValue(message, field, new_value); + // } else if (emit_unknown_enum_values) { + // reflection->MutableUnknownFields(message)->AddVarint( + // field->number(), new_value); + // } else { + // // convert value to a compatible/default value. + // new_value = CompatibleDowngrade(new_value); + // reflection->SetEnumValue(message, field, new_value); + // } + // } + bool SupportsUnknownEnumValues() const; + + // Returns the MessageFactory associated with this message. This can be + // useful for determining if a message is a generated message or not, for + // example: + // if (message->GetReflection()->GetMessageFactory() == + // google::protobuf::MessageFactory::generated_factory()) { + // // This is a generated message. + // } + // It can also be used to create more messages of this type, though + // Message::New() is an easier way to accomplish this. + MessageFactory* GetMessageFactory() const; + + private: + template + const RepeatedField& GetRepeatedFieldInternal( + const Message& message, const FieldDescriptor* field) const; + template + RepeatedField* MutableRepeatedFieldInternal( + Message* message, const FieldDescriptor* field) const; + template + const RepeatedPtrField& GetRepeatedPtrFieldInternal( + const Message& message, const FieldDescriptor* field) const; + template + RepeatedPtrField* MutableRepeatedPtrFieldInternal( + Message* message, const FieldDescriptor* field) const; + // Obtain a pointer to a Repeated Field Structure and do some type checking: + // on field->cpp_type(), + // on field->field_option().ctype() (if ctype >= 0) + // of field->message_type() (if message_type != nullptr). + // We use 2 routine rather than 4 (const vs mutable) x (scalar vs pointer). + void* MutableRawRepeatedField(Message* message, const FieldDescriptor* field, + FieldDescriptor::CppType, int ctype, + const Descriptor* message_type) const; + + const void* GetRawRepeatedField(const Message& message, + const FieldDescriptor* field, + FieldDescriptor::CppType cpptype, int ctype, + const Descriptor* message_type) const; + + // The following methods are used to implement (Mutable)RepeatedFieldRef. + // A Ref object will store a raw pointer to the repeated field data (obtained + // from RepeatedFieldData()) and a pointer to a Accessor (obtained from + // RepeatedFieldAccessor) which will be used to access the raw data. + + // Returns a raw pointer to the repeated field + // + // "cpp_type" and "message_type" are deduced from the type parameter T passed + // to Get(Mutable)RepeatedFieldRef. If T is a generated message type, + // "message_type" should be set to its descriptor. Otherwise "message_type" + // should be set to nullptr. Implementations of this method should check + // whether "cpp_type"/"message_type" is consistent with the actual type of the + // field. We use 1 routine rather than 2 (const vs mutable) because it is + // protected and it doesn't change the message. + void* RepeatedFieldData(Message* message, const FieldDescriptor* field, + FieldDescriptor::CppType cpp_type, + const Descriptor* message_type) const; + + // The returned pointer should point to a singleton instance which implements + // the RepeatedFieldAccessor interface. + const internal::RepeatedFieldAccessor* RepeatedFieldAccessor( + const FieldDescriptor* field) const; + + // Lists all fields of the message which are currently set, except for unknown + // fields and stripped fields. See ListFields for details. + void ListFieldsOmitStripped( + const Message& message, + std::vector* output) const; + + bool IsMessageStripped(const Descriptor* descriptor) const { + return schema_.IsMessageStripped(descriptor); + } + + friend class TextFormat; + + void ListFieldsMayFailOnStripped( + const Message& message, bool should_fail, + std::vector* output) const; + + const Descriptor* const descriptor_; + const internal::ReflectionSchema schema_; + const DescriptorPool* const descriptor_pool_; + MessageFactory* const message_factory_; + + // Last non weak field index. This is an optimization when most weak fields + // are at the end of the containing message. If a message proto doesn't + // contain weak fields, then this field equals descriptor_->field_count(). + int last_non_weak_field_index_; + + template + friend class RepeatedFieldRef; + template + friend class MutableRepeatedFieldRef; + friend class ::PROTOBUF_NAMESPACE_ID::MessageLayoutInspector; + friend class ::PROTOBUF_NAMESPACE_ID::AssignDescriptorsHelper; + friend class DynamicMessageFactory; + friend class python::MapReflectionFriend; +#define GOOGLE_PROTOBUF_HAS_CEL_MAP_REFLECTION_FRIEND + friend class expr::CelMapReflectionFriend; + friend class internal::MapFieldReflectionTest; + friend class internal::MapKeySorter; + friend class internal::WireFormat; + friend class internal::ReflectionOps; + // Needed for implementing text format for map. + friend class internal::MapFieldPrinterHelper; + + Reflection(const Descriptor* descriptor, + const internal::ReflectionSchema& schema, + const DescriptorPool* pool, MessageFactory* factory); + + // Special version for specialized implementations of string. We can't + // call MutableRawRepeatedField directly here because we don't have access to + // FieldOptions::* which are defined in descriptor.pb.h. Including that + // file here is not possible because it would cause a circular include cycle. + // We use 1 routine rather than 2 (const vs mutable) because it is private + // and mutable a repeated string field doesn't change the message. + void* MutableRawRepeatedString(Message* message, const FieldDescriptor* field, + bool is_string) const; + + friend class MapReflectionTester; + // Returns true if key is in map. Returns false if key is not in map field. + bool ContainsMapKey(const Message& message, const FieldDescriptor* field, + const MapKey& key) const; + + // If key is in map field: Saves the value pointer to val and returns + // false. If key in not in map field: Insert the key into map, saves + // value pointer to val and returns true. + bool InsertOrLookupMapValue(Message* message, const FieldDescriptor* field, + const MapKey& key, MapValueRef* val) const; + + // Delete and returns true if key is in the map field. Returns false + // otherwise. + bool DeleteMapValue(Message* message, const FieldDescriptor* field, + const MapKey& key) const; + + // Returns a MapIterator referring to the first element in the map field. + // If the map field is empty, this function returns the same as + // reflection::MapEnd. Mutation to the field may invalidate the iterator. + MapIterator MapBegin(Message* message, const FieldDescriptor* field) const; + + // Returns a MapIterator referring to the theoretical element that would + // follow the last element in the map field. It does not point to any + // real element. Mutation to the field may invalidate the iterator. + MapIterator MapEnd(Message* message, const FieldDescriptor* field) const; + + // Get the number of pair of a map field. The result may be + // different from FieldSize which can have duplicate keys. + int MapSize(const Message& message, const FieldDescriptor* field) const; + + // Help method for MapIterator. + friend class MapIterator; + friend class WireFormatForMapFieldTest; + internal::MapFieldBase* MutableMapData(Message* message, + const FieldDescriptor* field) const; + + const internal::MapFieldBase* GetMapData(const Message& message, + const FieldDescriptor* field) const; + + template + const T& GetRawNonOneof(const Message& message, + const FieldDescriptor* field) const; + template + T* MutableRawNonOneof(Message* message, const FieldDescriptor* field) const; + + template + const Type& GetRaw(const Message& message, + const FieldDescriptor* field) const; + template + inline Type* MutableRaw(Message* message, const FieldDescriptor* field) const; + template + const Type& DefaultRaw(const FieldDescriptor* field) const; + + inline const uint32* GetHasBits(const Message& message) const; + inline uint32* MutableHasBits(Message* message) const; + inline uint32 GetOneofCase(const Message& message, + const OneofDescriptor* oneof_descriptor) const; + inline uint32* MutableOneofCase( + Message* message, const OneofDescriptor* oneof_descriptor) const; + inline bool HasExtensionSet(const Message& message) const { + return schema_.HasExtensionSet(); + } + const internal::ExtensionSet& GetExtensionSet(const Message& message) const; + internal::ExtensionSet* MutableExtensionSet(Message* message) const; + inline Arena* GetArena(Message* message) const; + + inline const internal::InternalMetadata& GetInternalMetadata( + const Message& message) const; + + internal::InternalMetadata* MutableInternalMetadata(Message* message) const; + + inline bool IsInlined(const FieldDescriptor* field) const; + + inline bool HasBit(const Message& message, + const FieldDescriptor* field) const; + inline void SetBit(Message* message, const FieldDescriptor* field) const; + inline void ClearBit(Message* message, const FieldDescriptor* field) const; + inline void SwapBit(Message* message1, Message* message2, + const FieldDescriptor* field) const; + + // This function only swaps the field. Should swap corresponding has_bit + // before or after using this function. + void SwapField(Message* message1, Message* message2, + const FieldDescriptor* field) const; + + void SwapOneofField(Message* message1, Message* message2, + const OneofDescriptor* oneof_descriptor) const; + + inline bool HasOneofField(const Message& message, + const FieldDescriptor* field) const; + inline void SetOneofCase(Message* message, + const FieldDescriptor* field) const; + inline void ClearOneofField(Message* message, + const FieldDescriptor* field) const; + + template + inline const Type& GetField(const Message& message, + const FieldDescriptor* field) const; + template + inline void SetField(Message* message, const FieldDescriptor* field, + const Type& value) const; + template + inline Type* MutableField(Message* message, + const FieldDescriptor* field) const; + template + inline const Type& GetRepeatedField(const Message& message, + const FieldDescriptor* field, + int index) const; + template + inline const Type& GetRepeatedPtrField(const Message& message, + const FieldDescriptor* field, + int index) const; + template + inline void SetRepeatedField(Message* message, const FieldDescriptor* field, + int index, Type value) const; + template + inline Type* MutableRepeatedField(Message* message, + const FieldDescriptor* field, + int index) const; + template + inline void AddField(Message* message, const FieldDescriptor* field, + const Type& value) const; + template + inline Type* AddField(Message* message, const FieldDescriptor* field) const; + + int GetExtensionNumberOrDie(const Descriptor* type) const; + + // Internal versions of EnumValue API perform no checking. Called after checks + // by public methods. + void SetEnumValueInternal(Message* message, const FieldDescriptor* field, + int value) const; + void SetRepeatedEnumValueInternal(Message* message, + const FieldDescriptor* field, int index, + int value) const; + void AddEnumValueInternal(Message* message, const FieldDescriptor* field, + int value) const; + + Message* UnsafeArenaReleaseMessage(Message* message, + const FieldDescriptor* field, + MessageFactory* factory = nullptr) const; + + void UnsafeArenaSetAllocatedMessage(Message* message, Message* sub_message, + const FieldDescriptor* field) const; + + friend inline // inline so nobody can call this function. + void + RegisterAllTypesInternal(const Metadata* file_level_metadata, int size); + friend inline const char* ParseLenDelim(int field_number, + const FieldDescriptor* field, + Message* msg, + const Reflection* reflection, + const char* ptr, + internal::ParseContext* ctx); + friend inline const char* ParsePackedField(const FieldDescriptor* field, + Message* msg, + const Reflection* reflection, + const char* ptr, + internal::ParseContext* ctx); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Reflection); +}; + +// Abstract interface for a factory for message objects. +class PROTOBUF_EXPORT MessageFactory { + public: + inline MessageFactory() {} + virtual ~MessageFactory(); + + // Given a Descriptor, gets or constructs the default (prototype) Message + // of that type. You can then call that message's New() method to construct + // a mutable message of that type. + // + // Calling this method twice with the same Descriptor returns the same + // object. The returned object remains property of the factory. Also, any + // objects created by calling the prototype's New() method share some data + // with the prototype, so these must be destroyed before the MessageFactory + // is destroyed. + // + // The given descriptor must outlive the returned message, and hence must + // outlive the MessageFactory. + // + // Some implementations do not support all types. GetPrototype() will + // return nullptr if the descriptor passed in is not supported. + // + // This method may or may not be thread-safe depending on the implementation. + // Each implementation should document its own degree thread-safety. + virtual const Message* GetPrototype(const Descriptor* type) = 0; + + // Gets a MessageFactory which supports all generated, compiled-in messages. + // In other words, for any compiled-in type FooMessage, the following is true: + // MessageFactory::generated_factory()->GetPrototype( + // FooMessage::descriptor()) == FooMessage::default_instance() + // This factory supports all types which are found in + // DescriptorPool::generated_pool(). If given a descriptor from any other + // pool, GetPrototype() will return nullptr. (You can also check if a + // descriptor is for a generated message by checking if + // descriptor->file()->pool() == DescriptorPool::generated_pool().) + // + // This factory is 100% thread-safe; calling GetPrototype() does not modify + // any shared data. + // + // This factory is a singleton. The caller must not delete the object. + static MessageFactory* generated_factory(); + + // For internal use only: Registers a .proto file at static initialization + // time, to be placed in generated_factory. The first time GetPrototype() + // is called with a descriptor from this file, |register_messages| will be + // called, with the file name as the parameter. It must call + // InternalRegisterGeneratedMessage() (below) to register each message type + // in the file. This strange mechanism is necessary because descriptors are + // built lazily, so we can't register types by their descriptor until we + // know that the descriptor exists. |filename| must be a permanent string. + static void InternalRegisterGeneratedFile( + const google::protobuf::internal::DescriptorTable* table); + + // For internal use only: Registers a message type. Called only by the + // functions which are registered with InternalRegisterGeneratedFile(), + // above. + static void InternalRegisterGeneratedMessage(const Descriptor* descriptor, + const Message* prototype); + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MessageFactory); +}; + +#define DECLARE_GET_REPEATED_FIELD(TYPE) \ + template <> \ + PROTOBUF_EXPORT const RepeatedField& \ + Reflection::GetRepeatedFieldInternal( \ + const Message& message, const FieldDescriptor* field) const; \ + \ + template <> \ + PROTOBUF_EXPORT RepeatedField* \ + Reflection::MutableRepeatedFieldInternal( \ + Message * message, const FieldDescriptor* field) const; + +DECLARE_GET_REPEATED_FIELD(int32) +DECLARE_GET_REPEATED_FIELD(int64) +DECLARE_GET_REPEATED_FIELD(uint32) +DECLARE_GET_REPEATED_FIELD(uint64) +DECLARE_GET_REPEATED_FIELD(float) +DECLARE_GET_REPEATED_FIELD(double) +DECLARE_GET_REPEATED_FIELD(bool) + +#undef DECLARE_GET_REPEATED_FIELD + +// Tries to downcast this message to a generated message type. Returns nullptr +// if this class is not an instance of T. This works even if RTTI is disabled. +// +// This also has the effect of creating a strong reference to T that will +// prevent the linker from stripping it out at link time. This can be important +// if you are using a DynamicMessageFactory that delegates to the generated +// factory. +template +const T* DynamicCastToGenerated(const Message* from) { + // Compile-time assert that T is a generated type that has a + // default_instance() accessor, but avoid actually calling it. + const T& (*get_default_instance)() = &T::default_instance; + (void)get_default_instance; + + // Compile-time assert that T is a subclass of google::protobuf::Message. + const Message* unused = static_cast(nullptr); + (void)unused; + +#if PROTOBUF_RTTI + return dynamic_cast(from); +#else + bool ok = T::default_instance().GetReflection() == from->GetReflection(); + return ok ? down_cast(from) : nullptr; +#endif +} + +template +T* DynamicCastToGenerated(Message* from) { + const Message* message_const = from; + return const_cast(DynamicCastToGenerated(message_const)); +} + +// Call this function to ensure that this message's reflection is linked into +// the binary: +// +// google::protobuf::LinkMessageReflection(); +// +// This will ensure that the following lookup will succeed: +// +// DescriptorPool::generated_pool()->FindMessageTypeByName("FooMessage"); +// +// As a side-effect, it will also guarantee that anything else from the same +// .proto file will also be available for lookup in the generated pool. +// +// This function does not actually register the message, so it does not need +// to be called before the lookup. However it does need to occur in a function +// that cannot be stripped from the binary (ie. it must be reachable from main). +// +// Best practice is to call this function as close as possible to where the +// reflection is actually needed. This function is very cheap to call, so you +// should not need to worry about its runtime overhead except in the tightest +// of loops (on x86-64 it compiles into two "mov" instructions). +template +void LinkMessageReflection() { + internal::StrongReference(T::default_instance); +} + +// ============================================================================= +// Implementation details for {Get,Mutable}RawRepeatedPtrField. We provide +// specializations for , and and +// handle everything else with the default template which will match any type +// having a method with signature "static const google::protobuf::Descriptor* +// descriptor()". Such a type presumably is a descendant of google::protobuf::Message. + +template <> +inline const RepeatedPtrField& +Reflection::GetRepeatedPtrFieldInternal( + const Message& message, const FieldDescriptor* field) const { + return *static_cast*>( + MutableRawRepeatedString(const_cast(&message), field, true)); +} + +template <> +inline RepeatedPtrField* +Reflection::MutableRepeatedPtrFieldInternal( + Message* message, const FieldDescriptor* field) const { + return static_cast*>( + MutableRawRepeatedString(message, field, true)); +} + + +// ----- + +template <> +inline const RepeatedPtrField& Reflection::GetRepeatedPtrFieldInternal( + const Message& message, const FieldDescriptor* field) const { + return *static_cast*>(GetRawRepeatedField( + message, field, FieldDescriptor::CPPTYPE_MESSAGE, -1, nullptr)); +} + +template <> +inline RepeatedPtrField* Reflection::MutableRepeatedPtrFieldInternal( + Message* message, const FieldDescriptor* field) const { + return static_cast*>(MutableRawRepeatedField( + message, field, FieldDescriptor::CPPTYPE_MESSAGE, -1, nullptr)); +} + +template +inline const RepeatedPtrField& Reflection::GetRepeatedPtrFieldInternal( + const Message& message, const FieldDescriptor* field) const { + return *static_cast*>( + GetRawRepeatedField(message, field, FieldDescriptor::CPPTYPE_MESSAGE, -1, + PB::default_instance().GetDescriptor())); +} + +template +inline RepeatedPtrField* Reflection::MutableRepeatedPtrFieldInternal( + Message* message, const FieldDescriptor* field) const { + return static_cast*>( + MutableRawRepeatedField(message, field, FieldDescriptor::CPPTYPE_MESSAGE, + -1, PB::default_instance().GetDescriptor())); +} + +template +const Type& Reflection::DefaultRaw(const FieldDescriptor* field) const { + return *reinterpret_cast(schema_.GetFieldDefault(field)); +} +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_MESSAGE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/message_lite.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/message_lite.h new file mode 100644 index 0000000000000000000000000000000000000000..963173cd44ff28285942acea952dc7963d8cffc9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/message_lite.h @@ -0,0 +1,608 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Authors: wink@google.com (Wink Saville), +// kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Defines MessageLite, the abstract interface implemented by all (lite +// and non-lite) protocol message objects. + +#ifndef GOOGLE_PROTOBUF_MESSAGE_LITE_H__ +#define GOOGLE_PROTOBUF_MESSAGE_LITE_H__ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +template +class RepeatedPtrField; + +namespace io { + +class CodedInputStream; +class CodedOutputStream; +class ZeroCopyInputStream; +class ZeroCopyOutputStream; + +} // namespace io +namespace internal { + +// See parse_context.h for explanation +class ParseContext; + +class RepeatedPtrFieldBase; +class WireFormatLite; +class WeakFieldMap; + +// We compute sizes as size_t but cache them as int. This function converts a +// computed size to a cached size. Since we don't proceed with serialization +// if the total size was > INT_MAX, it is not important what this function +// returns for inputs > INT_MAX. However this case should not error or +// GOOGLE_CHECK-fail, because the full size_t resolution is still returned from +// ByteSizeLong() and checked against INT_MAX; we can catch the overflow +// there. +inline int ToCachedSize(size_t size) { return static_cast(size); } + +// We mainly calculate sizes in terms of size_t, but some functions that +// compute sizes return "int". These int sizes are expected to always be +// positive. This function is more efficient than casting an int to size_t +// directly on 64-bit platforms because it avoids making the compiler emit a +// sign extending instruction, which we don't want and don't want to pay for. +inline size_t FromIntSize(int size) { + // Convert to unsigned before widening so sign extension is not necessary. + return static_cast(size); +} + +// For cases where a legacy function returns an integer size. We GOOGLE_DCHECK() +// that the conversion will fit within an integer; if this is false then we +// are losing information. +inline int ToIntSize(size_t size) { + GOOGLE_DCHECK_LE(size, static_cast(INT_MAX)); + return static_cast(size); +} + +// This type wraps a variable whose constructor and destructor are explicitly +// called. It is particularly useful for a global variable, without its +// constructor and destructor run on start and end of the program lifetime. +// This circumvents the initial construction order fiasco, while keeping +// the address of the empty string a compile time constant. +// +// Pay special attention to the initialization state of the object. +// 1. The object is "uninitialized" to begin with. +// 2. Call Construct() or DefaultConstruct() only if the object is +// uninitialized. After the call, the object becomes "initialized". +// 3. Call get() and get_mutable() only if the object is initialized. +// 4. Call Destruct() only if the object is initialized. +// After the call, the object becomes uninitialized. +template +class ExplicitlyConstructed { + public: + void DefaultConstruct() { new (&union_) T(); } + + template + void Construct(Args&&... args) { + new (&union_) T(std::forward(args)...); + } + + void Destruct() { get_mutable()->~T(); } + + constexpr const T& get() const { return reinterpret_cast(union_); } + T* get_mutable() { return reinterpret_cast(&union_); } + + private: + // Prefer c++14 aligned_storage, but for compatibility this will do. + union AlignedUnion { + char space[sizeof(T)]; + int64 align_to_int64; + void* align_to_ptr; + } union_; +}; + +// Default empty string object. Don't use this directly. Instead, call +// GetEmptyString() to get the reference. +PROTOBUF_EXPORT extern ExplicitlyConstructed + fixed_address_empty_string; + + +PROTOBUF_EXPORT inline const std::string& GetEmptyStringAlreadyInited() { + return fixed_address_empty_string.get(); +} + +PROTOBUF_EXPORT size_t StringSpaceUsedExcludingSelfLong(const std::string& str); + +} // namespace internal + +// Interface to light weight protocol messages. +// +// This interface is implemented by all protocol message objects. Non-lite +// messages additionally implement the Message interface, which is a +// subclass of MessageLite. Use MessageLite instead when you only need +// the subset of features which it supports -- namely, nothing that uses +// descriptors or reflection. You can instruct the protocol compiler +// to generate classes which implement only MessageLite, not the full +// Message interface, by adding the following line to the .proto file: +// +// option optimize_for = LITE_RUNTIME; +// +// This is particularly useful on resource-constrained systems where +// the full protocol buffers runtime library is too big. +// +// Note that on non-constrained systems (e.g. servers) when you need +// to link in lots of protocol definitions, a better way to reduce +// total code footprint is to use optimize_for = CODE_SIZE. This +// will make the generated code smaller while still supporting all the +// same features (at the expense of speed). optimize_for = LITE_RUNTIME +// is best when you only have a small number of message types linked +// into your binary, in which case the size of the protocol buffers +// runtime itself is the biggest problem. +// +// Users must not derive from this class. Only the protocol compiler and +// the internal library are allowed to create subclasses. +class PROTOBUF_EXPORT MessageLite { + public: + inline MessageLite() {} + virtual ~MessageLite() = default; + + // Basic Operations ------------------------------------------------ + + // Get the name of this message type, e.g. "foo.bar.BazProto". + virtual std::string GetTypeName() const = 0; + + // Construct a new instance of the same type. Ownership is passed to the + // caller. + virtual MessageLite* New() const = 0; + + // Construct a new instance on the arena. Ownership is passed to the caller + // if arena is a NULL. Default implementation for backwards compatibility. + virtual MessageLite* New(Arena* arena) const; + + // Get the arena, if any, associated with this message. Virtual method + // required for generic operations but most arena-related operations should + // use the GetArena() generated-code method. Default implementation + // to reduce code size by avoiding the need for per-type implementations + // when types do not implement arena support. + Arena* GetArena() const { return _internal_metadata_.arena(); } + + // Get a pointer that may be equal to this message's arena, or may not be. + // If the value returned by this method is equal to some arena pointer, then + // this message is on that arena; however, if this message is on some arena, + // this method may or may not return that arena's pointer. As a tradeoff, + // this method may be more efficient than GetArena(). The intent is to allow + // underlying representations that use e.g. tagged pointers to sometimes + // store the arena pointer directly, and sometimes in a more indirect way, + // and allow a fastpath comparison against the arena pointer when it's easy + // to obtain. + void* GetMaybeArenaPointer() const { + return _internal_metadata_.raw_arena_ptr(); + } + + // Clear all fields of the message and set them to their default values. + // Clear() avoids freeing memory, assuming that any memory allocated + // to hold parts of the message will be needed again to hold the next + // message. If you actually want to free the memory used by a Message, + // you must delete it. + virtual void Clear() = 0; + + // Quickly check if all required fields have values set. + virtual bool IsInitialized() const = 0; + + // This is not implemented for Lite messages -- it just returns "(cannot + // determine missing fields for lite message)". However, it is implemented + // for full messages. See message.h. + virtual std::string InitializationErrorString() const; + + // If |other| is the exact same class as this, calls MergeFrom(). Otherwise, + // results are undefined (probably crash). + virtual void CheckTypeAndMergeFrom(const MessageLite& other) = 0; + + // These methods return a human-readable summary of the message. Note that + // since the MessageLite interface does not support reflection, there is very + // little information that these methods can provide. They are shadowed by + // methods of the same name on the Message interface which provide much more + // information. The methods here are intended primarily to facilitate code + // reuse for logic that needs to interoperate with both full and lite protos. + // + // The format of the returned string is subject to change, so please do not + // assume it will remain stable over time. + std::string DebugString() const; + std::string ShortDebugString() const { return DebugString(); } + // MessageLite::DebugString is already Utf8 Safe. This is to add compatibility + // with Message. + std::string Utf8DebugString() const { return DebugString(); } + + // Parsing --------------------------------------------------------- + // Methods for parsing in protocol buffer format. Most of these are + // just simple wrappers around MergeFromCodedStream(). Clear() will be + // called before merging the input. + + // Fill the message with a protocol buffer parsed from the given input + // stream. Returns false on a read error or if the input is in the wrong + // format. A successful return does not indicate the entire input is + // consumed, ensure you call ConsumedEntireMessage() to check that if + // applicable. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParseFromCodedStream( + io::CodedInputStream* input); + // Like ParseFromCodedStream(), but accepts messages that are missing + // required fields. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParsePartialFromCodedStream( + io::CodedInputStream* input); + // Read a protocol buffer from the given zero-copy input stream. If + // successful, the entire input will be consumed. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParseFromZeroCopyStream( + io::ZeroCopyInputStream* input); + // Like ParseFromZeroCopyStream(), but accepts messages that are missing + // required fields. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParsePartialFromZeroCopyStream( + io::ZeroCopyInputStream* input); + // Parse a protocol buffer from a file descriptor. If successful, the entire + // input will be consumed. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParseFromFileDescriptor( + int file_descriptor); + // Like ParseFromFileDescriptor(), but accepts messages that are missing + // required fields. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParsePartialFromFileDescriptor( + int file_descriptor); + // Parse a protocol buffer from a C++ istream. If successful, the entire + // input will be consumed. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParseFromIstream(std::istream* input); + // Like ParseFromIstream(), but accepts messages that are missing + // required fields. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParsePartialFromIstream( + std::istream* input); + // Read a protocol buffer from the given zero-copy input stream, expecting + // the message to be exactly "size" bytes long. If successful, exactly + // this many bytes will have been consumed from the input. + bool MergePartialFromBoundedZeroCopyStream(io::ZeroCopyInputStream* input, + int size); + // Like ParseFromBoundedZeroCopyStream(), but accepts messages that are + // missing required fields. + bool MergeFromBoundedZeroCopyStream(io::ZeroCopyInputStream* input, int size); + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParseFromBoundedZeroCopyStream( + io::ZeroCopyInputStream* input, int size); + // Like ParseFromBoundedZeroCopyStream(), but accepts messages that are + // missing required fields. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParsePartialFromBoundedZeroCopyStream( + io::ZeroCopyInputStream* input, int size); + // Parses a protocol buffer contained in a string. Returns true on success. + // This function takes a string in the (non-human-readable) binary wire + // format, matching the encoding output by MessageLite::SerializeToString(). + // If you'd like to convert a human-readable string into a protocol buffer + // object, see google::protobuf::TextFormat::ParseFromString(). + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParseFromString( + const std::string& data); + // Like ParseFromString(), but accepts messages that are missing + // required fields. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParsePartialFromString( + const std::string& data); + // Parse a protocol buffer contained in an array of bytes. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParseFromArray(const void* data, + int size); + // Like ParseFromArray(), but accepts messages that are missing + // required fields. + PROTOBUF_ATTRIBUTE_REINITIALIZES bool ParsePartialFromArray(const void* data, + int size); + + + // Reads a protocol buffer from the stream and merges it into this + // Message. Singular fields read from the what is + // already in the Message and repeated fields are appended to those + // already present. + // + // It is the responsibility of the caller to call input->LastTagWas() + // (for groups) or input->ConsumedEntireMessage() (for non-groups) after + // this returns to verify that the message's end was delimited correctly. + // + // ParseFromCodedStream() is implemented as Clear() followed by + // MergeFromCodedStream(). + bool MergeFromCodedStream(io::CodedInputStream* input); + + // Like MergeFromCodedStream(), but succeeds even if required fields are + // missing in the input. + // + // MergeFromCodedStream() is just implemented as MergePartialFromCodedStream() + // followed by IsInitialized(). + bool MergePartialFromCodedStream(io::CodedInputStream* input); + + // Merge a protocol buffer contained in a string. + bool MergeFromString(const std::string& data); + + + // Serialization --------------------------------------------------- + // Methods for serializing in protocol buffer format. Most of these + // are just simple wrappers around ByteSize() and SerializeWithCachedSizes(). + + // Write a protocol buffer of this message to the given output. Returns + // false on a write error. If the message is missing required fields, + // this may GOOGLE_CHECK-fail. + bool SerializeToCodedStream(io::CodedOutputStream* output) const; + // Like SerializeToCodedStream(), but allows missing required fields. + bool SerializePartialToCodedStream(io::CodedOutputStream* output) const; + // Write the message to the given zero-copy output stream. All required + // fields must be set. + bool SerializeToZeroCopyStream(io::ZeroCopyOutputStream* output) const; + // Like SerializeToZeroCopyStream(), but allows missing required fields. + bool SerializePartialToZeroCopyStream(io::ZeroCopyOutputStream* output) const; + // Serialize the message and store it in the given string. All required + // fields must be set. + bool SerializeToString(std::string* output) const; + // Like SerializeToString(), but allows missing required fields. + bool SerializePartialToString(std::string* output) const; + // Serialize the message and store it in the given byte array. All required + // fields must be set. + bool SerializeToArray(void* data, int size) const; + // Like SerializeToArray(), but allows missing required fields. + bool SerializePartialToArray(void* data, int size) const; + + // Make a string encoding the message. Is equivalent to calling + // SerializeToString() on a string and using that. Returns the empty + // string if SerializeToString() would have returned an error. + // Note: If you intend to generate many such strings, you may + // reduce heap fragmentation by instead re-using the same string + // object with calls to SerializeToString(). + std::string SerializeAsString() const; + // Like SerializeAsString(), but allows missing required fields. + std::string SerializePartialAsString() const; + + // Serialize the message and write it to the given file descriptor. All + // required fields must be set. + bool SerializeToFileDescriptor(int file_descriptor) const; + // Like SerializeToFileDescriptor(), but allows missing required fields. + bool SerializePartialToFileDescriptor(int file_descriptor) const; + // Serialize the message and write it to the given C++ ostream. All + // required fields must be set. + bool SerializeToOstream(std::ostream* output) const; + // Like SerializeToOstream(), but allows missing required fields. + bool SerializePartialToOstream(std::ostream* output) const; + + // Like SerializeToString(), but appends to the data to the string's + // existing contents. All required fields must be set. + bool AppendToString(std::string* output) const; + // Like AppendToString(), but allows missing required fields. + bool AppendPartialToString(std::string* output) const; + + + // Computes the serialized size of the message. This recursively calls + // ByteSizeLong() on all embedded messages. + // + // ByteSizeLong() is generally linear in the number of fields defined for the + // proto. + virtual size_t ByteSizeLong() const = 0; + + // Legacy ByteSize() API. + PROTOBUF_DEPRECATED_MSG("Please use ByteSizeLong() instead") + int ByteSize() const { return internal::ToIntSize(ByteSizeLong()); } + + // Serializes the message without recomputing the size. The message must not + // have changed since the last call to ByteSize(), and the value returned by + // ByteSize must be non-negative. Otherwise the results are undefined. + void SerializeWithCachedSizes(io::CodedOutputStream* output) const { + output->SetCur(_InternalSerialize(output->Cur(), output->EpsCopy())); + } + + // Functions below here are not part of the public interface. It isn't + // enforced, but they should be treated as private, and will be private + // at some future time. Unfortunately the implementation of the "friend" + // keyword in GCC is broken at the moment, but we expect it will be fixed. + + // Like SerializeWithCachedSizes, but writes directly to *target, returning + // a pointer to the byte immediately after the last byte written. "target" + // must point at a byte array of at least ByteSize() bytes. Whether to use + // deterministic serialization, e.g., maps in sorted order, is determined by + // CodedOutputStream::IsDefaultSerializationDeterministic(). + uint8* SerializeWithCachedSizesToArray(uint8* target) const; + + // Returns the result of the last call to ByteSize(). An embedded message's + // size is needed both to serialize it (because embedded messages are + // length-delimited) and to compute the outer message's size. Caching + // the size avoids computing it multiple times. + // + // ByteSize() does not automatically use the cached size when available + // because this would require invalidating it every time the message was + // modified, which would be too hard and expensive. (E.g. if a deeply-nested + // sub-message is changed, all of its parents' cached sizes would need to be + // invalidated, which is too much work for an otherwise inlined setter + // method.) + virtual int GetCachedSize() const = 0; + + virtual const char* _InternalParse(const char* /*ptr*/, + internal::ParseContext* /*ctx*/) { + return nullptr; + } + + protected: + template + static T* CreateMaybeMessage(Arena* arena) { + return Arena::CreateMaybeMessage(arena); + } + + inline explicit MessageLite(Arena* arena) : _internal_metadata_(arena) {} + + internal::InternalMetadata _internal_metadata_; + + public: + enum ParseFlags { + kMerge = 0, + kParse = 1, + kMergePartial = 2, + kParsePartial = 3, + kMergeWithAliasing = 4, + kParseWithAliasing = 5, + kMergePartialWithAliasing = 6, + kParsePartialWithAliasing = 7 + }; + + template + bool ParseFrom(const T& input); + + // Fast path when conditions match (ie. non-deterministic) + // uint8* _InternalSerialize(uint8* ptr) const; + virtual uint8* _InternalSerialize(uint8* ptr, + io::EpsCopyOutputStream* stream) const = 0; + + // Identical to IsInitialized() except that it logs an error message. + bool IsInitializedWithErrors() const { + if (IsInitialized()) return true; + LogInitializationErrorMessage(); + return false; + } + + private: + // TODO(gerbens) make this a pure abstract function + virtual const void* InternalGetTable() const { return NULL; } + + friend class internal::WireFormatLite; + friend class Message; + friend class internal::WeakFieldMap; + + void LogInitializationErrorMessage() const; + + bool MergeFromImpl(io::CodedInputStream* input, ParseFlags parse_flags); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MessageLite); +}; + +namespace internal { + +template +bool MergeFromImpl(StringPiece input, MessageLite* msg, + MessageLite::ParseFlags parse_flags); +extern template bool MergeFromImpl(StringPiece input, + MessageLite* msg, + MessageLite::ParseFlags parse_flags); +extern template bool MergeFromImpl(StringPiece input, + MessageLite* msg, + MessageLite::ParseFlags parse_flags); + +template +bool MergeFromImpl(io::ZeroCopyInputStream* input, MessageLite* msg, + MessageLite::ParseFlags parse_flags); +extern template bool MergeFromImpl(io::ZeroCopyInputStream* input, + MessageLite* msg, + MessageLite::ParseFlags parse_flags); +extern template bool MergeFromImpl(io::ZeroCopyInputStream* input, + MessageLite* msg, + MessageLite::ParseFlags parse_flags); + +struct BoundedZCIS { + io::ZeroCopyInputStream* zcis; + int limit; +}; + +template +bool MergeFromImpl(BoundedZCIS input, MessageLite* msg, + MessageLite::ParseFlags parse_flags); +extern template bool MergeFromImpl(BoundedZCIS input, MessageLite* msg, + MessageLite::ParseFlags parse_flags); +extern template bool MergeFromImpl(BoundedZCIS input, MessageLite* msg, + MessageLite::ParseFlags parse_flags); + +template +struct SourceWrapper; + +template +bool MergeFromImpl(const SourceWrapper& input, MessageLite* msg, + MessageLite::ParseFlags parse_flags) { + return input.template MergeInto(msg, parse_flags); +} + +} // namespace internal + +template +bool MessageLite::ParseFrom(const T& input) { + if (flags & kParse) Clear(); + constexpr bool alias = (flags & kMergeWithAliasing) != 0; + return internal::MergeFromImpl(input, this, flags); +} + +// =================================================================== +// Shutdown support. + + +// Shut down the entire protocol buffers library, deleting all static-duration +// objects allocated by the library or by generated .pb.cc files. +// +// There are two reasons you might want to call this: +// * You use a draconian definition of "memory leak" in which you expect +// every single malloc() to have a corresponding free(), even for objects +// which live until program exit. +// * You are writing a dynamically-loaded library which needs to clean up +// after itself when the library is unloaded. +// +// It is safe to call this multiple times. However, it is not safe to use +// any other part of the protocol buffers library after +// ShutdownProtobufLibrary() has been called. Furthermore this call is not +// thread safe, user needs to synchronize multiple calls. +PROTOBUF_EXPORT void ShutdownProtobufLibrary(); + +namespace internal { + +// Register a function to be called when ShutdownProtocolBuffers() is called. +PROTOBUF_EXPORT void OnShutdown(void (*func)()); +// Run an arbitrary function on an arg +PROTOBUF_EXPORT void OnShutdownRun(void (*f)(const void*), const void* arg); + +template +T* OnShutdownDelete(T* p) { + OnShutdownRun([](const void* pp) { delete static_cast(pp); }, p); + return p; +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_MESSAGE_LITE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/metadata.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..756936991766a4ef660ba971bc9b9172bd734b8d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/metadata.h @@ -0,0 +1,41 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_METADATA_H__ +#define GOOGLE_PROTOBUF_METADATA_H__ + +// TODO(b/151117630): Remove this file and all instances where it gets imported. + +#endif // GOOGLE_PROTOBUF_METADATA_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/metadata_lite.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/metadata_lite.h new file mode 100644 index 0000000000000000000000000000000000000000..eac4a8c4e754a4d25046409e514b1d32a6fd509e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/metadata_lite.h @@ -0,0 +1,253 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_METADATA_LITE_H__ +#define GOOGLE_PROTOBUF_METADATA_LITE_H__ + +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { +namespace internal { + +// This is the representation for messages that support arena allocation. It +// uses a tagged pointer to either store the Arena pointer, if there are no +// unknown fields, or a pointer to a block of memory with both the Arena pointer +// and the UnknownFieldSet, if there are unknown fields. This optimization +// allows for "zero-overhead" storage of the Arena pointer, relative to the +// above baseline implementation. +// +// The tagged pointer uses the LSB to disambiguate cases, and uses bit 0 == 0 to +// indicate an arena pointer and bit 0 == 1 to indicate a UFS+Arena-container +// pointer. +class InternalMetadata { + public: + InternalMetadata() : ptr_(nullptr) {} + explicit InternalMetadata(Arena* arena) : ptr_(arena) {} + + template + void Delete() { + // Note that Delete<> should be called not more than once. + if (have_unknown_fields() && arena() == NULL) { + delete PtrValue>(); + } + } + + PROTOBUF_ALWAYS_INLINE Arena* arena() const { + if (PROTOBUF_PREDICT_FALSE(have_unknown_fields())) { + return PtrValue()->arena; + } else { + return PtrValue(); + } + } + + PROTOBUF_ALWAYS_INLINE bool have_unknown_fields() const { + return PtrTag() == kTagContainer; + } + + PROTOBUF_ALWAYS_INLINE void* raw_arena_ptr() const { return ptr_; } + + template + PROTOBUF_ALWAYS_INLINE const T& unknown_fields( + const T& (*default_instance)()) const { + if (PROTOBUF_PREDICT_FALSE(have_unknown_fields())) { + return PtrValue>()->unknown_fields; + } else { + return default_instance(); + } + } + + template + PROTOBUF_ALWAYS_INLINE T* mutable_unknown_fields() { + if (PROTOBUF_PREDICT_TRUE(have_unknown_fields())) { + return &PtrValue>()->unknown_fields; + } else { + return mutable_unknown_fields_slow(); + } + } + + template + PROTOBUF_ALWAYS_INLINE void Swap(InternalMetadata* other) { + // Semantics here are that we swap only the unknown fields, not the arena + // pointer. We cannot simply swap ptr_ with other->ptr_ because we need to + // maintain our own arena ptr. Also, our ptr_ and other's ptr_ may be in + // different states (direct arena pointer vs. container with UFS) so we + // cannot simply swap ptr_ and then restore the arena pointers. We reuse + // UFS's swap implementation instead. + if (have_unknown_fields() || other->have_unknown_fields()) { + DoSwap(other->mutable_unknown_fields()); + } + } + + template + PROTOBUF_ALWAYS_INLINE void MergeFrom(const InternalMetadata& other) { + if (other.have_unknown_fields()) { + DoMergeFrom(other.unknown_fields(nullptr)); + } + } + + template + PROTOBUF_ALWAYS_INLINE void Clear() { + if (have_unknown_fields()) { + DoClear(); + } + } + + private: + void* ptr_; + + // Tagged pointer implementation. + enum { + // ptr_ is an Arena*. + kTagArena = 0, + // ptr_ is a Container*. + kTagContainer = 1, + }; + static constexpr intptr_t kPtrTagMask = 1; + static constexpr intptr_t kPtrValueMask = ~kPtrTagMask; + + // Accessors for pointer tag and pointer value. + PROTOBUF_ALWAYS_INLINE int PtrTag() const { + return reinterpret_cast(ptr_) & kPtrTagMask; + } + + template + U* PtrValue() const { + return reinterpret_cast(reinterpret_cast(ptr_) & + kPtrValueMask); + } + + // If ptr_'s tag is kTagContainer, it points to an instance of this struct. + struct ContainerBase { + Arena* arena; + }; + + template + struct Container : public ContainerBase { + T unknown_fields; + }; + + template + PROTOBUF_NOINLINE T* mutable_unknown_fields_slow() { + Arena* my_arena = arena(); + Container* container = Arena::Create>(my_arena); + // Two-step assignment works around a bug in clang's static analyzer: + // https://bugs.llvm.org/show_bug.cgi?id=34198. + ptr_ = container; + ptr_ = reinterpret_cast(reinterpret_cast(ptr_) | + kTagContainer); + container->arena = my_arena; + return &(container->unknown_fields); + } + + // Templated functions. + + template + void DoClear() { + mutable_unknown_fields()->Clear(); + } + + template + void DoMergeFrom(const T& other) { + mutable_unknown_fields()->MergeFrom(other); + } + + template + void DoSwap(T* other) { + mutable_unknown_fields()->Swap(other); + } +}; + +// String Template specializations. + +template <> +inline void InternalMetadata::DoClear() { + mutable_unknown_fields()->clear(); +} + +template <> +inline void InternalMetadata::DoMergeFrom( + const std::string& other) { + mutable_unknown_fields()->append(other); +} + +template <> +inline void InternalMetadata::DoSwap(std::string* other) { + mutable_unknown_fields()->swap(*other); +} + +// This helper RAII class is needed to efficiently parse unknown fields. We +// should only call mutable_unknown_fields if there are actual unknown fields. +// The obvious thing to just use a stack string and swap it at the end of +// the parse won't work, because the destructor of StringOutputStream needs to +// be called before we can modify the string (it check-fails). Using +// LiteUnknownFieldSetter setter(&_internal_metadata_); +// StringOutputStream stream(setter.buffer()); +// guarantees that the string is only swapped after stream is destroyed. +class PROTOBUF_EXPORT LiteUnknownFieldSetter { + public: + explicit LiteUnknownFieldSetter(InternalMetadata* metadata) + : metadata_(metadata) { + if (metadata->have_unknown_fields()) { + buffer_.swap(*metadata->mutable_unknown_fields()); + } + } + ~LiteUnknownFieldSetter() { + if (!buffer_.empty()) + metadata_->mutable_unknown_fields()->swap(buffer_); + } + std::string* buffer() { return &buffer_; } + + private: + InternalMetadata* metadata_; + std::string buffer_; +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_METADATA_LITE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/parse_context.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/parse_context.h new file mode 100644 index 0000000000000000000000000000000000000000..b462bedf592c44f444bfd6f62eb77eda7bf0f1f6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/parse_context.h @@ -0,0 +1,810 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PARSE_CONTEXT_H__ +#define GOOGLE_PROTOBUF_PARSE_CONTEXT_H__ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +namespace google { +namespace protobuf { + +class UnknownFieldSet; +class DescriptorPool; +class MessageFactory; + +namespace internal { + +// Template code below needs to know about the existence of these functions. +PROTOBUF_EXPORT void WriteVarint(uint32 num, uint64 val, std::string* s); +PROTOBUF_EXPORT void WriteLengthDelimited(uint32 num, StringPiece val, + std::string* s); +// Inline because it is just forwarding to s->WriteVarint +inline void WriteVarint(uint32 num, uint64 val, UnknownFieldSet* s); +inline void WriteLengthDelimited(uint32 num, StringPiece val, + UnknownFieldSet* s); + + +// The basic abstraction the parser is designed for is a slight modification +// of the ZeroCopyInputStream (ZCIS) abstraction. A ZCIS presents a serialized +// stream as a series of buffers that concatenate to the full stream. +// Pictorially a ZCIS presents a stream in chunks like so +// [---------------------------------------------------------------] +// [---------------------] chunk 1 +// [----------------------------] chunk 2 +// chunk 3 [--------------] +// +// Where the '-' represent the bytes which are vertically lined up with the +// bytes of the stream. The proto parser requires its input to be presented +// similarly with the extra +// property that each chunk has kSlopBytes past its end that overlaps with the +// first kSlopBytes of the next chunk, or if there is no next chunk at least its +// still valid to read those bytes. Again, pictorially, we now have +// +// [---------------------------------------------------------------] +// [-------------------....] chunk 1 +// [------------------------....] chunk 2 +// chunk 3 [------------------..**] +// chunk 4 [--****] +// Here '-' mean the bytes of the stream or chunk and '.' means bytes past the +// chunk that match up with the start of the next chunk. Above each chunk has +// 4 '.' after the chunk. In the case these 'overflow' bytes represents bytes +// past the stream, indicated by '*' above, their values are unspecified. It is +// still legal to read them (ie. should not segfault). Reading past the +// end should be detected by the user and indicated as an error. +// +// The reason for this, admittedly, unconventional invariant is to ruthlessly +// optimize the protobuf parser. Having an overlap helps in two important ways. +// Firstly it alleviates having to performing bounds checks if a piece of code +// is guaranteed to not read more than kSlopBytes. Secondly, and more +// importantly, the protobuf wireformat is such that reading a key/value pair is +// always less than 16 bytes. This removes the need to change to next buffer in +// the middle of reading primitive values. Hence there is no need to store and +// load the current position. + +class PROTOBUF_EXPORT EpsCopyInputStream { + public: + enum { kSlopBytes = 16, kMaxCordBytesToCopy = 512 }; + + explicit EpsCopyInputStream(bool enable_aliasing) + : aliasing_(enable_aliasing ? kOnPatch : kNoAliasing) {} + + void BackUp(const char* ptr) { + GOOGLE_DCHECK(ptr <= buffer_end_ + kSlopBytes); + int count; + if (next_chunk_ == buffer_) { + count = static_cast(buffer_end_ + kSlopBytes - ptr); + } else { + count = size_ + static_cast(buffer_end_ - ptr); + } + if (count > 0) StreamBackUp(count); + } + + // If return value is negative it's an error + PROTOBUF_MUST_USE_RESULT int PushLimit(const char* ptr, int limit) { + GOOGLE_DCHECK(limit >= 0 && limit <= INT_MAX - kSlopBytes); + // This add is safe due to the invariant above, because + // ptr - buffer_end_ <= kSlopBytes. + limit += static_cast(ptr - buffer_end_); + limit_end_ = buffer_end_ + (std::min)(0, limit); + auto old_limit = limit_; + limit_ = limit; + return old_limit - limit; + } + + PROTOBUF_MUST_USE_RESULT bool PopLimit(int delta) { + if (PROTOBUF_PREDICT_FALSE(!EndedAtLimit())) return false; + limit_ = limit_ + delta; + // TODO(gerbens) We could remove this line and hoist the code to + // DoneFallback. Study the perf/bin-size effects. + limit_end_ = buffer_end_ + (std::min)(0, limit_); + return true; + } + + PROTOBUF_MUST_USE_RESULT const char* Skip(const char* ptr, int size) { + if (size <= buffer_end_ + kSlopBytes - ptr) { + return ptr + size; + } + return SkipFallback(ptr, size); + } + PROTOBUF_MUST_USE_RESULT const char* ReadString(const char* ptr, int size, + std::string* s) { + if (size <= buffer_end_ + kSlopBytes - ptr) { + s->assign(ptr, size); + return ptr + size; + } + return ReadStringFallback(ptr, size, s); + } + PROTOBUF_MUST_USE_RESULT const char* AppendString(const char* ptr, int size, + std::string* s) { + if (size <= buffer_end_ + kSlopBytes - ptr) { + s->append(ptr, size); + return ptr + size; + } + return AppendStringFallback(ptr, size, s); + } + + template + PROTOBUF_MUST_USE_RESULT const char* ReadRepeatedFixed(const char* ptr, + Tag expected_tag, + RepeatedField* out); + + template + PROTOBUF_MUST_USE_RESULT const char* ReadPackedFixed(const char* ptr, + int size, + RepeatedField* out); + template + PROTOBUF_MUST_USE_RESULT const char* ReadPackedVarint(const char* ptr, + Add add); + + uint32 LastTag() const { return last_tag_minus_1_ + 1; } + bool ConsumeEndGroup(uint32 start_tag) { + bool res = last_tag_minus_1_ == start_tag; + last_tag_minus_1_ = 0; + return res; + } + bool EndedAtLimit() const { return last_tag_minus_1_ == 0; } + bool EndedAtEndOfStream() const { return last_tag_minus_1_ == 1; } + void SetLastTag(uint32 tag) { last_tag_minus_1_ = tag - 1; } + void SetEndOfStream() { last_tag_minus_1_ = 1; } + bool IsExceedingLimit(const char* ptr) { + return ptr > limit_end_ && + (next_chunk_ == nullptr || ptr - buffer_end_ > limit_); + } + int BytesUntilLimit(const char* ptr) const { + return limit_ + static_cast(buffer_end_ - ptr); + } + // Returns true if more data is available, if false is returned one has to + // call Done for further checks. + bool DataAvailable(const char* ptr) { return ptr < limit_end_; } + + protected: + // Returns true is limit (either an explicit limit or end of stream) is + // reached. It aligns *ptr across buffer seams. + // If limit is exceeded it returns true and ptr is set to null. + bool DoneWithCheck(const char** ptr, int d) { + GOOGLE_DCHECK(*ptr); + if (PROTOBUF_PREDICT_TRUE(*ptr < limit_end_)) return false; + // No need to fetch buffer if we ended on a limit in the slop region + if ((*ptr - buffer_end_) == limit_) return true; + auto res = DoneFallback(*ptr, d); + *ptr = res.first; + return res.second; + } + + const char* InitFrom(StringPiece flat) { + overall_limit_ = 0; + if (flat.size() > kSlopBytes) { + limit_ = kSlopBytes; + limit_end_ = buffer_end_ = flat.data() + flat.size() - kSlopBytes; + next_chunk_ = buffer_; + if (aliasing_ == kOnPatch) aliasing_ = kNoDelta; + return flat.data(); + } else { + std::memcpy(buffer_, flat.data(), flat.size()); + limit_ = 0; + limit_end_ = buffer_end_ = buffer_ + flat.size(); + next_chunk_ = nullptr; + if (aliasing_ == kOnPatch) { + aliasing_ = reinterpret_cast(flat.data()) - + reinterpret_cast(buffer_); + } + return buffer_; + } + } + + const char* InitFrom(io::ZeroCopyInputStream* zcis); + + const char* InitFrom(io::ZeroCopyInputStream* zcis, int limit) { + if (limit == -1) return InitFrom(zcis); + overall_limit_ = limit; + auto res = InitFrom(zcis); + limit_ = limit - static_cast(buffer_end_ - res); + limit_end_ = buffer_end_ + (std::min)(0, limit_); + return res; + } + + private: + const char* limit_end_; // buffer_end_ + min(limit_, 0) + const char* buffer_end_; + const char* next_chunk_; + int size_; + int limit_; // relative to buffer_end_; + io::ZeroCopyInputStream* zcis_ = nullptr; + char buffer_[2 * kSlopBytes] = {}; + enum { kNoAliasing = 0, kOnPatch = 1, kNoDelta = 2 }; + std::uintptr_t aliasing_ = kNoAliasing; + // This variable is used to communicate how the parse ended, in order to + // completely verify the parsed data. A wire-format parse can end because of + // one of the following conditions: + // 1) A parse can end on a pushed limit. + // 2) A parse can end on End Of Stream (EOS). + // 3) A parse can end on 0 tag (only valid for toplevel message). + // 4) A parse can end on an end-group tag. + // This variable should always be set to 0, which indicates case 1. If the + // parse terminated due to EOS (case 2), it's set to 1. In case the parse + // ended due to a terminating tag (case 3 and 4) it's set to (tag - 1). + // This var doesn't really belong in EpsCopyInputStream and should be part of + // the ParseContext, but case 2 is most easily and optimally implemented in + // DoneFallback. + uint32 last_tag_minus_1_ = 0; + int overall_limit_ = INT_MAX; // Overall limit independent of pushed limits. + // Pretty random large number that seems like a safe allocation on most + // systems. TODO(gerbens) do we need to set this as build flag? + enum { kSafeStringSize = 50000000 }; + + std::pair DoneFallback(const char* ptr, int d); + const char* Next(int overrun, int d); + const char* SkipFallback(const char* ptr, int size); + const char* AppendStringFallback(const char* ptr, int size, std::string* str); + const char* ReadStringFallback(const char* ptr, int size, std::string* str); + bool StreamNext(const void** data) { + bool res = zcis_->Next(data, &size_); + if (res) overall_limit_ -= size_; + return res; + } + void StreamBackUp(int count) { + zcis_->BackUp(count); + overall_limit_ += count; + } + + template + const char* AppendSize(const char* ptr, int size, const A& append) { + int chunk_size = buffer_end_ + kSlopBytes - ptr; + do { + GOOGLE_DCHECK(size > chunk_size); + append(ptr, chunk_size); + ptr += chunk_size; + size -= chunk_size; + // DoneFallBack asserts it isn't called when exactly on the limit. If this + // happens we fail the parse, as we are at the limit and still more bytes + // to read. + if (limit_ == kSlopBytes) return nullptr; + auto res = DoneFallback(ptr, -1); + if (res.second) return nullptr; // If done we passed the limit + ptr = res.first; + chunk_size = buffer_end_ + kSlopBytes - ptr; + } while (size > chunk_size); + append(ptr, size); + return ptr + size; + } + + // AppendUntilEnd appends data until a limit (either a PushLimit or end of + // stream. Normal payloads are from length delimited fields which have an + // explicit size. Reading until limit only comes when the string takes + // the place of a protobuf, ie RawMessage/StringRawMessage, lazy fields and + // implicit weak messages. We keep these methods private and friend them. + template + const char* AppendUntilEnd(const char* ptr, const A& append) { + while (!DoneWithCheck(&ptr, -1)) { + append(ptr, limit_end_ - ptr); + ptr = limit_end_; + } + return ptr; + } + + PROTOBUF_MUST_USE_RESULT const char* AppendString(const char* ptr, + std::string* str) { + return AppendUntilEnd( + ptr, [str](const char* p, ptrdiff_t s) { str->append(p, s); }); + } + friend class ImplicitWeakMessage; +}; + +// ParseContext holds all data that is global to the entire parse. Most +// importantly it contains the input stream, but also recursion depth and also +// stores the end group tag, in case a parser ended on a endgroup, to verify +// matching start/end group tags. +class PROTOBUF_EXPORT ParseContext : public EpsCopyInputStream { + public: + struct Data { + const DescriptorPool* pool = nullptr; + MessageFactory* factory = nullptr; + }; + + template + ParseContext(int depth, bool aliasing, const char** start, T&&... args) + : EpsCopyInputStream(aliasing), depth_(depth) { + *start = InitFrom(std::forward(args)...); + } + + void TrackCorrectEnding() { group_depth_ = 0; } + + bool Done(const char** ptr) { return DoneWithCheck(ptr, group_depth_); } + bool DoneNoSlopCheck(const char** ptr) { return DoneWithCheck(ptr, -1); } + + int depth() const { return depth_; } + + Data& data() { return data_; } + const Data& data() const { return data_; } + + template + PROTOBUF_MUST_USE_RESULT const char* ParseMessage(T* msg, const char* ptr); + // We outline when the type is generic and we go through a virtual + const char* ParseMessage(MessageLite* msg, const char* ptr); + const char* ParseMessage(Message* msg, const char* ptr); + + template + PROTOBUF_MUST_USE_RESULT PROTOBUF_ALWAYS_INLINE const char* ParseGroup( + T* msg, const char* ptr, uint32 tag) { + if (--depth_ < 0) return nullptr; + group_depth_++; + ptr = msg->_InternalParse(ptr, this); + group_depth_--; + depth_++; + if (PROTOBUF_PREDICT_FALSE(!ConsumeEndGroup(tag))) return nullptr; + return ptr; + } + + private: + // The context keeps an internal stack to keep track of the recursive + // part of the parse state. + // Current depth of the active parser, depth counts down. + // This is used to limit recursion depth (to prevent overflow on malicious + // data), but is also used to index in stack_ to store the current state. + int depth_; + // Unfortunately necessary for the fringe case of ending on 0 or end-group tag + // in the last kSlopBytes of a ZeroCopyInputStream chunk. + int group_depth_ = INT_MIN; + Data data_; +}; + +template +bool ExpectTag(const char* ptr) { + if (tag < 128) { + return *ptr == tag; + } else { + static_assert(tag < 128 * 128, "We only expect tags for 1 or 2 bytes"); + char buf[2] = {static_cast(tag | 0x80), static_cast(tag >> 7)}; + return std::memcmp(ptr, buf, 2) == 0; + } +} + +template +struct EndianHelper; + +template <> +struct EndianHelper<1> { + static uint8 Load(const void* p) { return *static_cast(p); } +}; + +template <> +struct EndianHelper<2> { + static uint16 Load(const void* p) { + uint16 tmp; + std::memcpy(&tmp, p, 2); +#ifndef PROTOBUF_LITTLE_ENDIAN + tmp = bswap_16(tmp); +#endif + return tmp; + } +}; + +template <> +struct EndianHelper<4> { + static uint32 Load(const void* p) { + uint32 tmp; + std::memcpy(&tmp, p, 4); +#ifndef PROTOBUF_LITTLE_ENDIAN + tmp = bswap_32(tmp); +#endif + return tmp; + } +}; + +template <> +struct EndianHelper<8> { + static uint64 Load(const void* p) { + uint64 tmp; + std::memcpy(&tmp, p, 8); +#ifndef PROTOBUF_LITTLE_ENDIAN + tmp = bswap_64(tmp); +#endif + return tmp; + } +}; + +template +T UnalignedLoad(const char* p) { + auto tmp = EndianHelper::Load(p); + T res; + memcpy(&res, &tmp, sizeof(T)); + return res; +} + +PROTOBUF_EXPORT +std::pair VarintParseSlow32(const char* p, uint32 res); +PROTOBUF_EXPORT +std::pair VarintParseSlow64(const char* p, uint32 res); + +inline const char* VarintParseSlow(const char* p, uint32 res, uint32* out) { + auto tmp = VarintParseSlow32(p, res); + *out = tmp.second; + return tmp.first; +} + +inline const char* VarintParseSlow(const char* p, uint32 res, uint64* out) { + auto tmp = VarintParseSlow64(p, res); + *out = tmp.second; + return tmp.first; +} + +template +PROTOBUF_MUST_USE_RESULT const char* VarintParse(const char* p, T* out) { + auto ptr = reinterpret_cast(p); + uint32 res = ptr[0]; + if (!(res & 0x80)) { + *out = res; + return p + 1; + } + uint32 byte = ptr[1]; + res += (byte - 1) << 7; + if (!(byte & 0x80)) { + *out = res; + return p + 2; + } + return VarintParseSlow(p, res, out); +} + +// Used for tags, could read up to 5 bytes which must be available. +// Caller must ensure its safe to call. + +PROTOBUF_EXPORT +std::pair ReadTagFallback(const char* p, uint32 res); + +// Same as ParseVarint but only accept 5 bytes at most. +inline const char* ReadTag(const char* p, uint32* out, uint32 /*max_tag*/ = 0) { + uint32 res = static_cast(p[0]); + if (res < 128) { + *out = res; + return p + 1; + } + uint32 second = static_cast(p[1]); + res += (second - 1) << 7; + if (second < 128) { + *out = res; + return p + 2; + } + auto tmp = ReadTagFallback(p, res); + *out = tmp.second; + return tmp.first; +} + +// Decode 2 consecutive bytes of a varint and returns the value, shifted left +// by 1. It simultaneous updates *ptr to *ptr + 1 or *ptr + 2 depending if the +// first byte's continuation bit is set. +// If bit 15 of return value is set (equivalent to the continuation bits of both +// bytes being set) the varint continues, otherwise the parse is done. On x86 +// movsx eax, dil +// add edi, eax +// adc [rsi], 1 +// add eax, eax +// and eax, edi +inline uint32 DecodeTwoBytes(const char** ptr) { + uint32 value = UnalignedLoad(*ptr); + // Sign extend the low byte continuation bit + uint32_t x = static_cast(value); + // This add is an amazing operation, it cancels the low byte continuation bit + // from y transferring it to the carry. Simultaneously it also shifts the 7 + // LSB left by one tightly against high byte varint bits. Hence value now + // contains the unpacked value shifted left by 1. + value += x; + // Use the carry to update the ptr appropriately. + *ptr += value < x ? 2 : 1; + return value & (x + x); // Mask out the high byte iff no continuation +} + +// More efficient varint parsing for big varints +inline const char* ParseBigVarint(const char* p, uint64* out) { + auto pnew = p; + auto tmp = DecodeTwoBytes(&pnew); + uint64 res = tmp >> 1; + if (PROTOBUF_PREDICT_TRUE(std::int16_t(tmp) >= 0)) { + *out = res; + return pnew; + } + for (std::uint32_t i = 1; i < 5; i++) { + pnew = p + 2 * i; + tmp = DecodeTwoBytes(&pnew); + res += (static_cast(tmp) - 2) << (14 * i - 1); + if (PROTOBUF_PREDICT_TRUE(std::int16_t(tmp) >= 0)) { + *out = res; + return pnew; + } + } + return nullptr; +} + +PROTOBUF_EXPORT +std::pair ReadSizeFallback(const char* p, uint32 first); +// Used for tags, could read up to 5 bytes which must be available. Additionally +// it makes sure the unsigned value fits a int32, otherwise returns nullptr. +// Caller must ensure its safe to call. +inline uint32 ReadSize(const char** pp) { + auto p = *pp; + uint32 res = static_cast(p[0]); + if (res < 128) { + *pp = p + 1; + return res; + } + auto x = ReadSizeFallback(p, res); + *pp = x.first; + return x.second; +} + +// Some convenience functions to simplify the generated parse loop code. +// Returning the value and updating the buffer pointer allows for nicer +// function composition. We rely on the compiler to inline this. +// Also in debug compiles having local scoped variables tend to generated +// stack frames that scale as O(num fields). +inline uint64 ReadVarint64(const char** p) { + uint64 tmp; + *p = VarintParse(*p, &tmp); + return tmp; +} + +inline uint32 ReadVarint32(const char** p) { + uint32 tmp; + *p = VarintParse(*p, &tmp); + return tmp; +} + +inline int64 ReadVarintZigZag64(const char** p) { + uint64 tmp; + *p = VarintParse(*p, &tmp); + return WireFormatLite::ZigZagDecode64(tmp); +} + +inline int32 ReadVarintZigZag32(const char** p) { + uint64 tmp; + *p = VarintParse(*p, &tmp); + return WireFormatLite::ZigZagDecode32(static_cast(tmp)); +} + +template +PROTOBUF_MUST_USE_RESULT const char* ParseContext::ParseMessage( + T* msg, const char* ptr) { + int size = ReadSize(&ptr); + if (!ptr) return nullptr; + auto old = PushLimit(ptr, size); + if (--depth_ < 0) return nullptr; + ptr = msg->_InternalParse(ptr, this); + if (PROTOBUF_PREDICT_FALSE(ptr == nullptr)) return nullptr; + depth_++; + if (!PopLimit(old)) return nullptr; + return ptr; +} + +template +const char* EpsCopyInputStream::ReadPackedVarint(const char* ptr, Add add) { + int size = ReadSize(&ptr); + if (ptr == nullptr) return nullptr; + auto old = PushLimit(ptr, size); + if (old < 0) return nullptr; + while (!DoneWithCheck(&ptr, -1)) { + uint64 varint; + ptr = VarintParse(ptr, &varint); + if (!ptr) return nullptr; + add(varint); + } + if (!PopLimit(old)) return nullptr; + return ptr; +} + +// Helper for verification of utf8 +PROTOBUF_EXPORT +bool VerifyUTF8(StringPiece s, const char* field_name); + +inline bool VerifyUTF8(const std::string* s, const char* field_name) { + return VerifyUTF8(*s, field_name); +} + +// All the string parsers with or without UTF checking and for all CTypes. +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* InlineGreedyStringParser( + std::string* s, const char* ptr, ParseContext* ctx); + + +// Add any of the following lines to debug which parse function is failing. + +#define GOOGLE_PROTOBUF_ASSERT_RETURN(predicate, ret) \ + if (!(predicate)) { \ + /* ::raise(SIGINT); */ \ + /* GOOGLE_LOG(ERROR) << "Parse failure"; */ \ + return ret; \ + } + +#define GOOGLE_PROTOBUF_PARSER_ASSERT(predicate) \ + GOOGLE_PROTOBUF_ASSERT_RETURN(predicate, nullptr) + +template +PROTOBUF_MUST_USE_RESULT const char* FieldParser(uint64 tag, T& field_parser, + const char* ptr, + ParseContext* ctx) { + uint32 number = tag >> 3; + GOOGLE_PROTOBUF_PARSER_ASSERT(number != 0); + using WireType = internal::WireFormatLite::WireType; + switch (tag & 7) { + case WireType::WIRETYPE_VARINT: { + uint64 value; + ptr = VarintParse(ptr, &value); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + field_parser.AddVarint(number, value); + break; + } + case WireType::WIRETYPE_FIXED64: { + uint64 value = UnalignedLoad(ptr); + ptr += 8; + field_parser.AddFixed64(number, value); + break; + } + case WireType::WIRETYPE_LENGTH_DELIMITED: { + ptr = field_parser.ParseLengthDelimited(number, ptr, ctx); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + case WireType::WIRETYPE_START_GROUP: { + ptr = field_parser.ParseGroup(number, ptr, ctx); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + case WireType::WIRETYPE_END_GROUP: { + GOOGLE_LOG(FATAL) << "Can't happen"; + break; + } + case WireType::WIRETYPE_FIXED32: { + uint32 value = UnalignedLoad(ptr); + ptr += 4; + field_parser.AddFixed32(number, value); + break; + } + default: + return nullptr; + } + return ptr; +} + +template +PROTOBUF_MUST_USE_RESULT const char* WireFormatParser(T& field_parser, + const char* ptr, + ParseContext* ctx) { + while (!ctx->Done(&ptr)) { + uint32 tag; + ptr = ReadTag(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); + if (tag == 0 || (tag & 7) == 4) { + ctx->SetLastTag(tag); + return ptr; + } + ptr = FieldParser(tag, field_parser, ptr, ctx); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); + } + return ptr; +} + +// The packed parsers parse repeated numeric primitives directly into the +// corresponding field + +// These are packed varints +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedInt32Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedUInt32Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedInt64Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedUInt64Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedSInt32Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedSInt64Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedEnumParser( + void* object, const char* ptr, ParseContext* ctx); + +template +PROTOBUF_MUST_USE_RESULT const char* PackedEnumParser( + void* object, const char* ptr, ParseContext* ctx, bool (*is_valid)(int), + InternalMetadata* metadata, int field_num) { + return ctx->ReadPackedVarint( + ptr, [object, is_valid, metadata, field_num](uint64 val) { + if (is_valid(val)) { + static_cast*>(object)->Add(val); + } else { + WriteVarint(field_num, val, metadata->mutable_unknown_fields()); + } + }); +} + +template +PROTOBUF_MUST_USE_RESULT const char* PackedEnumParserArg( + void* object, const char* ptr, ParseContext* ctx, + bool (*is_valid)(const void*, int), const void* data, + InternalMetadata* metadata, int field_num) { + return ctx->ReadPackedVarint( + ptr, [object, is_valid, data, metadata, field_num](uint64 val) { + if (is_valid(data, val)) { + static_cast*>(object)->Add(val); + } else { + WriteVarint(field_num, val, metadata->mutable_unknown_fields()); + } + }); +} + +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedBoolParser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedFixed32Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedSFixed32Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedFixed64Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedSFixed64Parser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedFloatParser( + void* object, const char* ptr, ParseContext* ctx); +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* PackedDoubleParser( + void* object, const char* ptr, ParseContext* ctx); + +// This is the only recursive parser. +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* UnknownGroupLiteParse( + std::string* unknown, const char* ptr, ParseContext* ctx); +// This is a helper to for the UnknownGroupLiteParse but is actually also +// useful in the generated code. It uses overload on std::string* vs +// UnknownFieldSet* to make the generated code isomorphic between full and lite. +PROTOBUF_EXPORT PROTOBUF_MUST_USE_RESULT const char* UnknownFieldParse( + uint32 tag, std::string* unknown, const char* ptr, ParseContext* ctx); + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_PARSE_CONTEXT_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/port.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/port.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb7d4fcdf21632bf9a342aa72d3d5c15c8c1615 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/port.h @@ -0,0 +1,48 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A common header that is included across all protobuf headers. We do our best +// to avoid #defining any macros here; instead we generally put macros in +// port_def.inc and port_undef.inc so they are not visible from outside of +// protobuf. + +#ifndef GOOGLE_PROTOBUF_PORT_H__ +#define GOOGLE_PROTOBUF_PORT_H__ + + +#include + + +#endif // GOOGLE_PROTOBUF_PORT_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/reflection.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/reflection.h new file mode 100644 index 0000000000000000000000000000000000000000..ff2da6f7f6b11af5f6fc0dce7be23fb67fc51345 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/reflection.h @@ -0,0 +1,568 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This header defines the RepeatedFieldRef class template used to access +// repeated fields with protobuf reflection API. +#ifndef GOOGLE_PROTOBUF_REFLECTION_H__ +#define GOOGLE_PROTOBUF_REFLECTION_H__ + +#include + +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +namespace google { +namespace protobuf { +namespace internal { +template +struct RefTypeTraits; +} // namespace internal + +template +RepeatedFieldRef Reflection::GetRepeatedFieldRef( + const Message& message, const FieldDescriptor* field) const { + return RepeatedFieldRef(message, field); +} + +template +MutableRepeatedFieldRef Reflection::GetMutableRepeatedFieldRef( + Message* message, const FieldDescriptor* field) const { + return MutableRepeatedFieldRef(message, field); +} + +// RepeatedFieldRef definition for non-message types. +template +class RepeatedFieldRef< + T, typename std::enable_if::value>::type> { + typedef typename internal::RefTypeTraits::iterator IteratorType; + typedef typename internal::RefTypeTraits::AccessorType AccessorType; + + public: + bool empty() const { return accessor_->IsEmpty(data_); } + int size() const { return accessor_->Size(data_); } + T Get(int index) const { return accessor_->template Get(data_, index); } + + typedef IteratorType iterator; + typedef IteratorType const_iterator; + typedef T value_type; + typedef T& reference; + typedef const T& const_reference; + typedef int size_type; + typedef ptrdiff_t difference_type; + + iterator begin() const { return iterator(data_, accessor_, true); } + iterator end() const { return iterator(data_, accessor_, false); } + + private: + friend class Reflection; + RepeatedFieldRef(const Message& message, const FieldDescriptor* field) { + const Reflection* reflection = message.GetReflection(); + data_ = reflection->RepeatedFieldData(const_cast(&message), field, + internal::RefTypeTraits::cpp_type, + NULL); + accessor_ = reflection->RepeatedFieldAccessor(field); + } + + const void* data_; + const AccessorType* accessor_; +}; + +// MutableRepeatedFieldRef definition for non-message types. +template +class MutableRepeatedFieldRef< + T, typename std::enable_if::value>::type> { + typedef typename internal::RefTypeTraits::AccessorType AccessorType; + + public: + bool empty() const { return accessor_->IsEmpty(data_); } + int size() const { return accessor_->Size(data_); } + T Get(int index) const { return accessor_->template Get(data_, index); } + + void Set(int index, const T& value) const { + accessor_->template Set(data_, index, value); + } + void Add(const T& value) const { accessor_->template Add(data_, value); } + void RemoveLast() const { accessor_->RemoveLast(data_); } + void SwapElements(int index1, int index2) const { + accessor_->SwapElements(data_, index1, index2); + } + void Clear() const { accessor_->Clear(data_); } + + void Swap(const MutableRepeatedFieldRef& other) const { + accessor_->Swap(data_, other.accessor_, other.data_); + } + + template + void MergeFrom(const Container& container) const { + typedef typename Container::const_iterator Iterator; + for (Iterator it = container.begin(); it != container.end(); ++it) { + Add(*it); + } + } + template + void CopyFrom(const Container& container) const { + Clear(); + MergeFrom(container); + } + + private: + friend class Reflection; + MutableRepeatedFieldRef(Message* message, const FieldDescriptor* field) { + const Reflection* reflection = message->GetReflection(); + data_ = reflection->RepeatedFieldData( + message, field, internal::RefTypeTraits::cpp_type, NULL); + accessor_ = reflection->RepeatedFieldAccessor(field); + } + + void* data_; + const AccessorType* accessor_; +}; + +// RepeatedFieldRef definition for message types. +template +class RepeatedFieldRef< + T, typename std::enable_if::value>::type> { + typedef typename internal::RefTypeTraits::iterator IteratorType; + typedef typename internal::RefTypeTraits::AccessorType AccessorType; + + public: + bool empty() const { return accessor_->IsEmpty(data_); } + int size() const { return accessor_->Size(data_); } + // This method returns a reference to the underlying message object if it + // exists. If a message object doesn't exist (e.g., data stored in serialized + // form), scratch_space will be filled with the data and a reference to it + // will be returned. + // + // Example: + // RepeatedFieldRef h = ... + // unique_ptr scratch_space(h.NewMessage()); + // const Message& item = h.Get(index, scratch_space.get()); + const T& Get(int index, T* scratch_space) const { + return *static_cast(accessor_->Get(data_, index, scratch_space)); + } + // Create a new message of the same type as the messages stored in this + // repeated field. Caller takes ownership of the returned object. + T* NewMessage() const { return static_cast(default_instance_->New()); } + + typedef IteratorType iterator; + typedef IteratorType const_iterator; + typedef T value_type; + typedef T& reference; + typedef const T& const_reference; + typedef int size_type; + typedef ptrdiff_t difference_type; + + iterator begin() const { + return iterator(data_, accessor_, true, NewMessage()); + } + iterator end() const { + // The end iterator must not be dereferenced, no need for scratch space. + return iterator(data_, accessor_, false, nullptr); + } + + private: + friend class Reflection; + RepeatedFieldRef(const Message& message, const FieldDescriptor* field) { + const Reflection* reflection = message.GetReflection(); + data_ = reflection->RepeatedFieldData( + const_cast(&message), field, + internal::RefTypeTraits::cpp_type, + internal::RefTypeTraits::GetMessageFieldDescriptor()); + accessor_ = reflection->RepeatedFieldAccessor(field); + default_instance_ = + reflection->GetMessageFactory()->GetPrototype(field->message_type()); + } + + const void* data_; + const AccessorType* accessor_; + const Message* default_instance_; +}; + +// MutableRepeatedFieldRef definition for message types. +template +class MutableRepeatedFieldRef< + T, typename std::enable_if::value>::type> { + typedef typename internal::RefTypeTraits::AccessorType AccessorType; + + public: + bool empty() const { return accessor_->IsEmpty(data_); } + int size() const { return accessor_->Size(data_); } + // See comments for RepeatedFieldRef::Get() + const T& Get(int index, T* scratch_space) const { + return *static_cast(accessor_->Get(data_, index, scratch_space)); + } + // Create a new message of the same type as the messages stored in this + // repeated field. Caller takes ownership of the returned object. + T* NewMessage() const { return static_cast(default_instance_->New()); } + + void Set(int index, const T& value) const { + accessor_->Set(data_, index, &value); + } + void Add(const T& value) const { accessor_->Add(data_, &value); } + void RemoveLast() const { accessor_->RemoveLast(data_); } + void SwapElements(int index1, int index2) const { + accessor_->SwapElements(data_, index1, index2); + } + void Clear() const { accessor_->Clear(data_); } + + void Swap(const MutableRepeatedFieldRef& other) const { + accessor_->Swap(data_, other.accessor_, other.data_); + } + + template + void MergeFrom(const Container& container) const { + typedef typename Container::const_iterator Iterator; + for (Iterator it = container.begin(); it != container.end(); ++it) { + Add(*it); + } + } + template + void CopyFrom(const Container& container) const { + Clear(); + MergeFrom(container); + } + + private: + friend class Reflection; + MutableRepeatedFieldRef(Message* message, const FieldDescriptor* field) { + const Reflection* reflection = message->GetReflection(); + data_ = reflection->RepeatedFieldData( + message, field, internal::RefTypeTraits::cpp_type, + internal::RefTypeTraits::GetMessageFieldDescriptor()); + accessor_ = reflection->RepeatedFieldAccessor(field); + default_instance_ = + reflection->GetMessageFactory()->GetPrototype(field->message_type()); + } + + void* data_; + const AccessorType* accessor_; + const Message* default_instance_; +}; + +namespace internal { +// Interfaces used to implement reflection RepeatedFieldRef API. +// Reflection::GetRepeatedAccessor() should return a pointer to an singleton +// object that implements the below interface. +// +// This interface passes/returns values using void pointers. The actual type +// of the value depends on the field's cpp_type. Following is a mapping from +// cpp_type to the type that should be used in this interface: +// +// field->cpp_type() T Actual type of void* +// CPPTYPE_INT32 int32 int32 +// CPPTYPE_UINT32 uint32 uint32 +// CPPTYPE_INT64 int64 int64 +// CPPTYPE_UINT64 uint64 uint64 +// CPPTYPE_DOUBLE double double +// CPPTYPE_FLOAT float float +// CPPTYPE_BOOL bool bool +// CPPTYPE_ENUM generated enum type int32 +// CPPTYPE_STRING string std::string +// CPPTYPE_MESSAGE generated message type google::protobuf::Message +// or google::protobuf::Message +// +// Note that for enums we use int32 in the interface. +// +// You can map from T to the actual type using RefTypeTraits: +// typedef RefTypeTraits::AccessorValueType ActualType; +class PROTOBUF_EXPORT RepeatedFieldAccessor { + public: + // Typedefs for clarity. + typedef void Field; + typedef void Value; + typedef void Iterator; + + virtual bool IsEmpty(const Field* data) const = 0; + virtual int Size(const Field* data) const = 0; + // Depends on the underlying representation of the repeated field, this + // method can return a pointer to the underlying object if such an object + // exists, or fill the data into scratch_space and return scratch_space. + // Callers of this method must ensure scratch_space is a valid pointer + // to a mutable object of the correct type. + virtual const Value* Get(const Field* data, int index, + Value* scratch_space) const = 0; + + virtual void Clear(Field* data) const = 0; + virtual void Set(Field* data, int index, const Value* value) const = 0; + virtual void Add(Field* data, const Value* value) const = 0; + virtual void RemoveLast(Field* data) const = 0; + virtual void SwapElements(Field* data, int index1, int index2) const = 0; + virtual void Swap(Field* data, const RepeatedFieldAccessor* other_mutator, + Field* other_data) const = 0; + + // Create an iterator that points at the beginning of the repeated field. + virtual Iterator* BeginIterator(const Field* data) const = 0; + // Create an iterator that points at the end of the repeated field. + virtual Iterator* EndIterator(const Field* data) const = 0; + // Make a copy of an iterator and return the new copy. + virtual Iterator* CopyIterator(const Field* data, + const Iterator* iterator) const = 0; + // Move an iterator to point to the next element. + virtual Iterator* AdvanceIterator(const Field* data, + Iterator* iterator) const = 0; + // Compare whether two iterators point to the same element. + virtual bool EqualsIterator(const Field* data, const Iterator* a, + const Iterator* b) const = 0; + // Delete an iterator created by BeginIterator(), EndIterator() and + // CopyIterator(). + virtual void DeleteIterator(const Field* data, Iterator* iterator) const = 0; + // Like Get() but for iterators. + virtual const Value* GetIteratorValue(const Field* data, + const Iterator* iterator, + Value* scratch_space) const = 0; + + // Templated methods that make using this interface easier for non-message + // types. + template + T Get(const Field* data, int index) const { + typedef typename RefTypeTraits::AccessorValueType ActualType; + ActualType scratch_space; + return static_cast(*reinterpret_cast( + Get(data, index, static_cast(&scratch_space)))); + } + + template + void Set(Field* data, int index, const ValueType& value) const { + typedef typename RefTypeTraits::AccessorValueType ActualType; + // In this RepeatedFieldAccessor interface we pass/return data using + // raw pointers. Type of the data these raw pointers point to should + // be ActualType. Here we have a ValueType object and want a ActualType + // pointer. We can't cast a ValueType pointer to an ActualType pointer + // directly because their type might be different (for enums ValueType + // may be a generated enum type while ActualType is int32). To be safe + // we make a copy to get a temporary ActualType object and use it. + ActualType tmp = static_cast(value); + Set(data, index, static_cast(&tmp)); + } + + template + void Add(Field* data, const ValueType& value) const { + typedef typename RefTypeTraits::AccessorValueType ActualType; + // In this RepeatedFieldAccessor interface we pass/return data using + // raw pointers. Type of the data these raw pointers point to should + // be ActualType. Here we have a ValueType object and want a ActualType + // pointer. We can't cast a ValueType pointer to an ActualType pointer + // directly because their type might be different (for enums ValueType + // may be a generated enum type while ActualType is int32). To be safe + // we make a copy to get a temporary ActualType object and use it. + ActualType tmp = static_cast(value); + Add(data, static_cast(&tmp)); + } + + protected: + // We want the destructor to be completely trivial as to allow it to be + // a function local static. Hence we make it non-virtual and protected, + // this class only live as part of a global singleton and should not be + // deleted. + ~RepeatedFieldAccessor() = default; +}; + +// Implement (Mutable)RepeatedFieldRef::iterator +template +class RepeatedFieldRefIterator + : public std::iterator { + typedef typename RefTypeTraits::AccessorValueType AccessorValueType; + typedef typename RefTypeTraits::IteratorValueType IteratorValueType; + typedef typename RefTypeTraits::IteratorPointerType IteratorPointerType; + + public: + // Constructor for non-message fields. + RepeatedFieldRefIterator(const void* data, + const RepeatedFieldAccessor* accessor, bool begin) + : data_(data), + accessor_(accessor), + iterator_(begin ? accessor->BeginIterator(data) + : accessor->EndIterator(data)), + // The end iterator must not be dereferenced, no need for scratch space. + scratch_space_(begin ? new AccessorValueType : nullptr) {} + // Constructor for message fields. + RepeatedFieldRefIterator(const void* data, + const RepeatedFieldAccessor* accessor, bool begin, + AccessorValueType* scratch_space) + : data_(data), + accessor_(accessor), + iterator_(begin ? accessor->BeginIterator(data) + : accessor->EndIterator(data)), + scratch_space_(scratch_space) {} + ~RepeatedFieldRefIterator() { accessor_->DeleteIterator(data_, iterator_); } + RepeatedFieldRefIterator operator++(int) { + RepeatedFieldRefIterator tmp(*this); + iterator_ = accessor_->AdvanceIterator(data_, iterator_); + return tmp; + } + RepeatedFieldRefIterator& operator++() { + iterator_ = accessor_->AdvanceIterator(data_, iterator_); + return *this; + } + IteratorValueType operator*() const { + return static_cast( + *static_cast(accessor_->GetIteratorValue( + data_, iterator_, scratch_space_.get()))); + } + IteratorPointerType operator->() const { + return static_cast( + accessor_->GetIteratorValue(data_, iterator_, scratch_space_.get())); + } + bool operator!=(const RepeatedFieldRefIterator& other) const { + assert(data_ == other.data_); + assert(accessor_ == other.accessor_); + return !accessor_->EqualsIterator(data_, iterator_, other.iterator_); + } + bool operator==(const RepeatedFieldRefIterator& other) const { + return !this->operator!=(other); + } + + RepeatedFieldRefIterator(const RepeatedFieldRefIterator& other) + : data_(other.data_), + accessor_(other.accessor_), + iterator_(accessor_->CopyIterator(data_, other.iterator_)) {} + RepeatedFieldRefIterator& operator=(const RepeatedFieldRefIterator& other) { + if (this != &other) { + accessor_->DeleteIterator(data_, iterator_); + data_ = other.data_; + accessor_ = other.accessor_; + iterator_ = accessor_->CopyIterator(data_, other.iterator_); + } + return *this; + } + + protected: + const void* data_; + const RepeatedFieldAccessor* accessor_; + void* iterator_; + std::unique_ptr scratch_space_; +}; + +// TypeTraits that maps the type parameter T of RepeatedFieldRef or +// MutableRepeatedFieldRef to corresponding iterator type, +// RepeatedFieldAccessor type, etc. +template +struct PrimitiveTraits { + static constexpr bool is_primitive = false; +}; +#define DEFINE_PRIMITIVE(TYPE, type) \ + template <> \ + struct PrimitiveTraits { \ + static const bool is_primitive = true; \ + static const FieldDescriptor::CppType cpp_type = \ + FieldDescriptor::CPPTYPE_##TYPE; \ + }; +DEFINE_PRIMITIVE(INT32, int32) +DEFINE_PRIMITIVE(UINT32, uint32) +DEFINE_PRIMITIVE(INT64, int64) +DEFINE_PRIMITIVE(UINT64, uint64) +DEFINE_PRIMITIVE(FLOAT, float) +DEFINE_PRIMITIVE(DOUBLE, double) +DEFINE_PRIMITIVE(BOOL, bool) +#undef DEFINE_PRIMITIVE + +template +struct RefTypeTraits< + T, typename std::enable_if::is_primitive>::type> { + typedef RepeatedFieldRefIterator iterator; + typedef RepeatedFieldAccessor AccessorType; + typedef T AccessorValueType; + typedef T IteratorValueType; + typedef T* IteratorPointerType; + static constexpr FieldDescriptor::CppType cpp_type = + PrimitiveTraits::cpp_type; + static const Descriptor* GetMessageFieldDescriptor() { return NULL; } +}; + +template +struct RefTypeTraits< + T, typename std::enable_if::value>::type> { + typedef RepeatedFieldRefIterator iterator; + typedef RepeatedFieldAccessor AccessorType; + // We use int32 for repeated enums in RepeatedFieldAccessor. + typedef int32 AccessorValueType; + typedef T IteratorValueType; + typedef int32* IteratorPointerType; + static constexpr FieldDescriptor::CppType cpp_type = + FieldDescriptor::CPPTYPE_ENUM; + static const Descriptor* GetMessageFieldDescriptor() { return NULL; } +}; + +template +struct RefTypeTraits< + T, typename std::enable_if::value>::type> { + typedef RepeatedFieldRefIterator iterator; + typedef RepeatedFieldAccessor AccessorType; + typedef std::string AccessorValueType; + typedef const std::string IteratorValueType; + typedef const std::string* IteratorPointerType; + static constexpr FieldDescriptor::CppType cpp_type = + FieldDescriptor::CPPTYPE_STRING; + static const Descriptor* GetMessageFieldDescriptor() { return NULL; } +}; + +template +struct MessageDescriptorGetter { + static const Descriptor* get() { + return T::default_instance().GetDescriptor(); + } +}; +template <> +struct MessageDescriptorGetter { + static const Descriptor* get() { return NULL; } +}; + +template +struct RefTypeTraits< + T, typename std::enable_if::value>::type> { + typedef RepeatedFieldRefIterator iterator; + typedef RepeatedFieldAccessor AccessorType; + typedef Message AccessorValueType; + typedef const T& IteratorValueType; + typedef const T* IteratorPointerType; + static constexpr FieldDescriptor::CppType cpp_type = + FieldDescriptor::CPPTYPE_MESSAGE; + static const Descriptor* GetMessageFieldDescriptor() { + return MessageDescriptorGetter::get(); + } +}; +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_REFLECTION_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/reflection_ops.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/reflection_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..857faa035fb8c2f0501dd69eac03a24b70f401a9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/reflection_ops.h @@ -0,0 +1,96 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This header is logically internal, but is made public because it is used +// from protocol-compiler-generated code, which may reside in other components. + +#ifndef GOOGLE_PROTOBUF_REFLECTION_OPS_H__ +#define GOOGLE_PROTOBUF_REFLECTION_OPS_H__ + +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +namespace google { +namespace protobuf { +namespace internal { + +// Basic operations that can be performed using reflection. +// These can be used as a cheap way to implement the corresponding +// methods of the Message interface, though they are likely to be +// slower than implementations tailored for the specific message type. +// +// This class should stay limited to operations needed to implement +// the Message interface. +// +// This class is really a namespace that contains only static methods. +class PROTOBUF_EXPORT ReflectionOps { + public: + static void Copy(const Message& from, Message* to); + static void Merge(const Message& from, Message* to); + static void Clear(Message* message); + static bool IsInitialized(const Message& message); + static bool IsInitialized(const Message& message, bool check_fields, + bool check_descendants); + static void DiscardUnknownFields(Message* message); + + // Finds all unset required fields in the message and adds their full + // paths (e.g. "foo.bar[5].baz") to *names. "prefix" will be attached to + // the front of each name. + static void FindInitializationErrors(const Message& message, + const std::string& prefix, + std::vector* errors); + + private: + // All methods are static. No need to construct. + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ReflectionOps); +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_REFLECTION_OPS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/repeated_field.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/repeated_field.h new file mode 100644 index 0000000000000000000000000000000000000000..23fc61b92ea04e40bf3666a25c41d93b32183711 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/repeated_field.h @@ -0,0 +1,2853 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// RepeatedField and RepeatedPtrField are used by generated protocol message +// classes to manipulate repeated fields. These classes are very similar to +// STL's vector, but include a number of optimizations found to be useful +// specifically in the case of Protocol Buffers. RepeatedPtrField is +// particularly different from STL vector as it manages ownership of the +// pointers that it contains. +// +// Typically, clients should not need to access RepeatedField objects directly, +// but should instead use the accessor functions generated automatically by the +// protocol compiler. + +#ifndef GOOGLE_PROTOBUF_REPEATED_FIELD_H__ +#define GOOGLE_PROTOBUF_REPEATED_FIELD_H__ + +#include +#ifdef _MSC_VER +// This is required for min/max on VS2013 only. +#include +#endif + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +// Must be included last. +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +class Message; +class Reflection; + +template +struct WeakRepeatedPtrField; + +namespace internal { + +class MergePartialFromCodedStreamHelper; + +// kRepeatedFieldLowerClampLimit is the smallest size that will be allocated +// when growing a repeated field. +constexpr int kRepeatedFieldLowerClampLimit = 4; + +// kRepeatedFieldUpperClampLimit is the lowest signed integer value that +// overflows when multiplied by 2 (which is undefined behavior). Sizes above +// this will clamp to the maximum int value instead of following exponential +// growth when growing a repeated field. +constexpr int kRepeatedFieldUpperClampLimit = + (std::numeric_limits::max() / 2) + 1; + +// A utility function for logging that doesn't need any template types. +void LogIndexOutOfBounds(int index, int size); + +template +inline int CalculateReserve(Iter begin, Iter end, std::forward_iterator_tag) { + return static_cast(std::distance(begin, end)); +} + +template +inline int CalculateReserve(Iter /*begin*/, Iter /*end*/, + std::input_iterator_tag /*unused*/) { + return -1; +} + +template +inline int CalculateReserve(Iter begin, Iter end) { + typedef typename std::iterator_traits::iterator_category Category; + return CalculateReserve(begin, end, Category()); +} + +// Swaps two blocks of memory of size sizeof(T). +template +inline void SwapBlock(char* p, char* q) { + T tmp; + memcpy(&tmp, p, sizeof(T)); + memcpy(p, q, sizeof(T)); + memcpy(q, &tmp, sizeof(T)); +} + +// Swaps two blocks of memory of size kSize: +// template void memswap(char* p, char* q); + +template +inline typename std::enable_if<(kSize == 0), void>::type memswap(char*, char*) { +} + +#define PROTO_MEMSWAP_DEF_SIZE(reg_type, max_size) \ + template \ + typename std::enable_if<(kSize >= sizeof(reg_type) && kSize < (max_size)), \ + void>::type \ + memswap(char* p, char* q) { \ + SwapBlock(p, q); \ + memswap(p + sizeof(reg_type), \ + q + sizeof(reg_type)); \ + } + +PROTO_MEMSWAP_DEF_SIZE(uint8, 2) +PROTO_MEMSWAP_DEF_SIZE(uint16, 4) +PROTO_MEMSWAP_DEF_SIZE(uint32, 8) + +#ifdef __SIZEOF_INT128__ +PROTO_MEMSWAP_DEF_SIZE(uint64, 16) +PROTO_MEMSWAP_DEF_SIZE(__uint128_t, (1u << 31)) +#else +PROTO_MEMSWAP_DEF_SIZE(uint64, (1u << 31)) +#endif + +#undef PROTO_MEMSWAP_DEF_SIZE + +} // namespace internal + +// RepeatedField is used to represent repeated fields of a primitive type (in +// other words, everything except strings and nested Messages). Most users will +// not ever use a RepeatedField directly; they will use the get-by-index, +// set-by-index, and add accessors that are generated for all repeated fields. +template +class RepeatedField final { + static_assert( + alignof(Arena) >= alignof(Element), + "We only support types that have an alignment smaller than Arena"); + + public: + RepeatedField(); + explicit RepeatedField(Arena* arena); + RepeatedField(const RepeatedField& other); + template + RepeatedField(Iter begin, const Iter& end); + ~RepeatedField(); + + RepeatedField& operator=(const RepeatedField& other); + + RepeatedField(RepeatedField&& other) noexcept; + RepeatedField& operator=(RepeatedField&& other) noexcept; + + bool empty() const; + int size() const; + + const Element& Get(int index) const; + Element* Mutable(int index); + + const Element& operator[](int index) const { return Get(index); } + Element& operator[](int index) { return *Mutable(index); } + + const Element& at(int index) const; + Element& at(int index); + + void Set(int index, const Element& value); + void Add(const Element& value); + // Appends a new element and return a pointer to it. + // The new element is uninitialized if |Element| is a POD type. + Element* Add(); + // Append elements in the range [begin, end) after reserving + // the appropriate number of elements. + template + void Add(Iter begin, Iter end); + + // Remove the last element in the array. + void RemoveLast(); + + // Extract elements with indices in "[start .. start+num-1]". + // Copy them into "elements[0 .. num-1]" if "elements" is not NULL. + // Caution: implementation also moves elements with indices [start+num ..]. + // Calling this routine inside a loop can cause quadratic behavior. + void ExtractSubrange(int start, int num, Element* elements); + + void Clear(); + void MergeFrom(const RepeatedField& other); + void CopyFrom(const RepeatedField& other); + + // Reserve space to expand the field to at least the given size. If the + // array is grown, it will always be at least doubled in size. + void Reserve(int new_size); + + // Resize the RepeatedField to a new, smaller size. This is O(1). + void Truncate(int new_size); + + void AddAlreadyReserved(const Element& value); + // Appends a new element and return a pointer to it. + // The new element is uninitialized if |Element| is a POD type. + // Should be called only if Capacity() > Size(). + Element* AddAlreadyReserved(); + Element* AddNAlreadyReserved(int elements); + int Capacity() const; + + // Like STL resize. Uses value to fill appended elements. + // Like Truncate() if new_size <= size(), otherwise this is + // O(new_size - size()). + void Resize(int new_size, const Element& value); + + // Gets the underlying array. This pointer is possibly invalidated by + // any add or remove operation. + Element* mutable_data(); + const Element* data() const; + + // Swap entire contents with "other". If they are separate arenas then, copies + // data between each other. + void Swap(RepeatedField* other); + + // Swap entire contents with "other". Should be called only if the caller can + // guarantee that both repeated fields are on the same arena or are on the + // heap. Swapping between different arenas is disallowed and caught by a + // GOOGLE_DCHECK (see API docs for details). + void UnsafeArenaSwap(RepeatedField* other); + + // Swap two elements. + void SwapElements(int index1, int index2); + + // STL-like iterator support + typedef Element* iterator; + typedef const Element* const_iterator; + typedef Element value_type; + typedef value_type& reference; + typedef const value_type& const_reference; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef int size_type; + typedef ptrdiff_t difference_type; + + iterator begin(); + const_iterator begin() const; + const_iterator cbegin() const; + iterator end(); + const_iterator end() const; + const_iterator cend() const; + + // Reverse iterator support + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; + reverse_iterator rbegin() { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + + // Returns the number of bytes used by the repeated field, excluding + // sizeof(*this) + size_t SpaceUsedExcludingSelfLong() const; + + int SpaceUsedExcludingSelf() const { + return internal::ToIntSize(SpaceUsedExcludingSelfLong()); + } + + // Removes the element referenced by position. + // + // Returns an iterator to the element immediately following the removed + // element. + // + // Invalidates all iterators at or after the removed element, including end(). + iterator erase(const_iterator position); + + // Removes the elements in the range [first, last). + // + // Returns an iterator to the element immediately following the removed range. + // + // Invalidates all iterators at or after the removed range, including end(). + iterator erase(const_iterator first, const_iterator last); + + // Get the Arena on which this RepeatedField stores its elements. + inline Arena* GetArena() const { + return (total_size_ == 0) ? static_cast(arena_or_elements_) + : rep()->arena; + } + + // For internal use only. + // + // This is public due to it being called by generated code. + inline void InternalSwap(RepeatedField* other); + + private: + static constexpr int kInitialSize = 0; + // A note on the representation here (see also comment below for + // RepeatedPtrFieldBase's struct Rep): + // + // We maintain the same sizeof(RepeatedField) as before we added arena support + // so that we do not degrade performance by bloating memory usage. Directly + // adding an arena_ element to RepeatedField is quite costly. By using + // indirection in this way, we keep the same size when the RepeatedField is + // empty (common case), and add only an 8-byte header to the elements array + // when non-empty. We make sure to place the size fields directly in the + // RepeatedField class to avoid costly cache misses due to the indirection. + int current_size_; + int total_size_; + struct Rep { + Arena* arena; + Element elements[1]; + }; + // We can not use sizeof(Rep) - sizeof(Element) due to the trailing padding on + // the struct. We can not use sizeof(Arena*) as well because there might be + // a "gap" after the field arena and before the field elements (e.g., when + // Element is double and pointer is 32bit). + static const size_t kRepHeaderSize; + + // If total_size_ == 0 this points to an Arena otherwise it points to the + // elements member of a Rep struct. Using this invariant allows the storage of + // the arena pointer without an extra allocation in the constructor. + void* arena_or_elements_; + + // Return pointer to elements array. + // pre-condition: the array must have been allocated. + Element* elements() const { + GOOGLE_DCHECK_GT(total_size_, 0); + // Because of above pre-condition this cast is safe. + return unsafe_elements(); + } + + // Return pointer to elements array if it exists otherwise either null or + // a invalid pointer is returned. This only happens for empty repeated fields, + // where you can't dereference this pointer anyway (it's empty). + Element* unsafe_elements() const { + return static_cast(arena_or_elements_); + } + + // Return pointer to the Rep struct. + // pre-condition: the Rep must have been allocated, ie elements() is safe. + Rep* rep() const { + char* addr = reinterpret_cast(elements()) - offsetof(Rep, elements); + return reinterpret_cast(addr); + } + + friend class Arena; + typedef void InternalArenaConstructable_; + + // Move the contents of |from| into |to|, possibly clobbering |from| in the + // process. For primitive types this is just a memcpy(), but it could be + // specialized for non-primitive types to, say, swap each element instead. + void MoveArray(Element* to, Element* from, int size); + + // Copy the elements of |from| into |to|. + void CopyArray(Element* to, const Element* from, int size); + + // Internal helper to delete all elements and deallocate the storage. + // If Element has a trivial destructor (for example, if it's a fundamental + // type, like int32), the loop will be removed by the optimizer. + void InternalDeallocate(Rep* rep, int size) { + if (rep != NULL) { + Element* e = &rep->elements[0]; + Element* limit = &rep->elements[size]; + for (; e < limit; e++) { + e->~Element(); + } + if (rep->arena == NULL) { +#if defined(__GXX_DELETE_WITH_SIZE__) || defined(__cpp_sized_deallocation) + const size_t bytes = size * sizeof(*e) + kRepHeaderSize; + ::operator delete(static_cast(rep), bytes); +#else + ::operator delete(static_cast(rep)); +#endif + } + } + } + + // This class is a performance wrapper around RepeatedField::Add(const T&) + // function. In general unless a RepeatedField is a local stack variable LLVM + // has a hard time optimizing Add. The machine code tends to be + // loop: + // mov %size, dword ptr [%repeated_field] // load + // cmp %size, dword ptr [%repeated_field + 4] + // jae fallback + // mov %buffer, qword ptr [%repeated_field + 8] + // mov dword [%buffer + %size * 4], %value + // inc %size // increment + // mov dword ptr [%repeated_field], %size // store + // jmp loop + // + // This puts a load/store in each iteration of the important loop variable + // size. It's a pretty bad compile that happens even in simple cases, but + // largely the presence of the fallback path disturbs the compilers mem-to-reg + // analysis. + // + // This class takes ownership of a repeated field for the duration of it's + // lifetime. The repeated field should not be accessed during this time, ie. + // only access through this class is allowed. This class should always be a + // function local stack variable. Intended use + // + // void AddSequence(const int* begin, const int* end, RepeatedField* out) + // { + // RepeatedFieldAdder adder(out); // Take ownership of out + // for (auto it = begin; it != end; ++it) { + // adder.Add(*it); + // } + // } + // + // Typically due to the fact adder is a local stack variable. The compiler + // will be successful in mem-to-reg transformation and the machine code will + // be loop: cmp %size, %capacity jae fallback mov dword ptr [%buffer + %size * + // 4], %val inc %size jmp loop + // + // The first version executes at 7 cycles per iteration while the second + // version near 1 or 2 cycles. + template ::value> + class FastAdderImpl { + public: + explicit FastAdderImpl(RepeatedField* rf) : repeated_field_(rf) { + index_ = repeated_field_->current_size_; + capacity_ = repeated_field_->total_size_; + buffer_ = repeated_field_->unsafe_elements(); + } + ~FastAdderImpl() { repeated_field_->current_size_ = index_; } + + void Add(Element val) { + if (index_ == capacity_) { + repeated_field_->current_size_ = index_; + repeated_field_->Reserve(index_ + 1); + capacity_ = repeated_field_->total_size_; + buffer_ = repeated_field_->unsafe_elements(); + } + buffer_[index_++] = val; + } + + private: + RepeatedField* repeated_field_; + int index_; + int capacity_; + Element* buffer_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FastAdderImpl); + }; + + // FastAdder is a wrapper for adding fields. The specialization above handles + // POD types more efficiently than RepeatedField. + template + class FastAdderImpl { + public: + explicit FastAdderImpl(RepeatedField* rf) : repeated_field_(rf) {} + void Add(const Element& val) { repeated_field_->Add(val); } + + private: + RepeatedField* repeated_field_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FastAdderImpl); + }; + + using FastAdder = FastAdderImpl<>; + + friend class TestRepeatedFieldHelper; + friend class ::google::protobuf::internal::ParseContext; +}; + +template +const size_t RepeatedField::kRepHeaderSize = + reinterpret_cast(&reinterpret_cast(16)->elements[0]) - 16; + +namespace internal { +template +class RepeatedPtrIterator; +template +class RepeatedPtrOverPtrsIterator; +} // namespace internal + +namespace internal { + +// This is a helper template to copy an array of elements efficiently when they +// have a trivial copy constructor, and correctly otherwise. This really +// shouldn't be necessary, but our compiler doesn't optimize std::copy very +// effectively. +template ::value> +struct ElementCopier { + void operator()(Element* to, const Element* from, int array_size); +}; + +} // namespace internal + +namespace internal { + +// type-traits helper for RepeatedPtrFieldBase: we only want to invoke +// arena-related "copy if on different arena" behavior if the necessary methods +// exist on the contained type. In particular, we rely on MergeFrom() existing +// as a general proxy for the fact that a copy will work, and we also provide a +// specific override for std::string*. +template +struct TypeImplementsMergeBehaviorProbeForMergeFrom { + typedef char HasMerge; + typedef long HasNoMerge; + + // We accept either of: + // - void MergeFrom(const T& other) + // - bool MergeFrom(const T& other) + // + // We mangle these names a bit to avoid compatibility issues in 'unclean' + // include environments that may have, e.g., "#define test ..." (yes, this + // exists). + template + struct CheckType; + template + static HasMerge Check(CheckType*); + template + static HasMerge Check(CheckType*); + template + static HasNoMerge Check(...); + + // Resolves to either std::true_type or std::false_type. + typedef std::integral_constant(0)) == sizeof(HasMerge))> + type; +}; + +template +struct TypeImplementsMergeBehavior + : TypeImplementsMergeBehaviorProbeForMergeFrom {}; + + +template <> +struct TypeImplementsMergeBehavior { + typedef std::true_type type; +}; + +template +struct IsMovable + : std::integral_constant::value && + std::is_move_assignable::value> {}; + +// This is the common base class for RepeatedPtrFields. It deals only in void* +// pointers. Users should not use this interface directly. +// +// The methods of this interface correspond to the methods of RepeatedPtrField, +// but may have a template argument called TypeHandler. Its signature is: +// class TypeHandler { +// public: +// typedef MyType Type; +// static Type* New(); +// static Type* NewFromPrototype(const Type* prototype, +// Arena* arena); +// static void Delete(Type*); +// static void Clear(Type*); +// static void Merge(const Type& from, Type* to); +// +// // Only needs to be implemented if SpaceUsedExcludingSelf() is called. +// static int SpaceUsedLong(const Type&); +// }; +class PROTOBUF_EXPORT RepeatedPtrFieldBase { + protected: + RepeatedPtrFieldBase(); + explicit RepeatedPtrFieldBase(Arena* arena); + ~RepeatedPtrFieldBase() { +#ifndef NDEBUG + // Try to trigger segfault / asan failure in non-opt builds. If arena_ + // lifetime has ended before the destructor. + if (arena_) (void)arena_->SpaceAllocated(); +#endif + } + + public: + // Must be called from destructor. + template + void Destroy(); + + protected: + bool empty() const; + int size() const; + + template + const typename TypeHandler::Type& at(int index) const; + template + typename TypeHandler::Type& at(int index); + + template + typename TypeHandler::Type* Mutable(int index); + template + void Delete(int index); + template + typename TypeHandler::Type* Add(typename TypeHandler::Type* prototype = NULL); + + public: + // The next few methods are public so that they can be called from generated + // code when implicit weak fields are used, but they should never be called by + // application code. + + template + const typename TypeHandler::Type& Get(int index) const; + + // Creates and adds an element using the given prototype, without introducing + // a link-time dependency on the concrete message type. This method is used to + // implement implicit weak fields. The prototype may be NULL, in which case an + // ImplicitWeakMessage will be used as a placeholder. + MessageLite* AddWeak(const MessageLite* prototype); + + template + void Clear(); + + template + void MergeFrom(const RepeatedPtrFieldBase& other); + + inline void InternalSwap(RepeatedPtrFieldBase* other); + + protected: + template < + typename TypeHandler, + typename std::enable_if::type* = nullptr> + void Add(typename TypeHandler::Type&& value); + + template + void RemoveLast(); + template + void CopyFrom(const RepeatedPtrFieldBase& other); + + void CloseGap(int start, int num); + + void Reserve(int new_size); + + int Capacity() const; + + // Used for constructing iterators. + void* const* raw_data() const; + void** raw_mutable_data() const; + + template + typename TypeHandler::Type** mutable_data(); + template + const typename TypeHandler::Type* const* data() const; + + template + PROTOBUF_ALWAYS_INLINE void Swap(RepeatedPtrFieldBase* other); + + void SwapElements(int index1, int index2); + + template + size_t SpaceUsedExcludingSelfLong() const; + + // Advanced memory management -------------------------------------- + + // Like Add(), but if there are no cleared objects to use, returns NULL. + template + typename TypeHandler::Type* AddFromCleared(); + + template + void AddAllocated(typename TypeHandler::Type* value) { + typename TypeImplementsMergeBehavior::type t; + AddAllocatedInternal(value, t); + } + + template + void UnsafeArenaAddAllocated(typename TypeHandler::Type* value); + + template + typename TypeHandler::Type* ReleaseLast() { + typename TypeImplementsMergeBehavior::type t; + return ReleaseLastInternal(t); + } + + // Releases last element and returns it, but does not do out-of-arena copy. + // And just returns the raw pointer to the contained element in the arena. + template + typename TypeHandler::Type* UnsafeArenaReleaseLast(); + + int ClearedCount() const; + template + void AddCleared(typename TypeHandler::Type* value); + template + typename TypeHandler::Type* ReleaseCleared(); + + template + void AddAllocatedInternal(typename TypeHandler::Type* value, std::true_type); + template + void AddAllocatedInternal(typename TypeHandler::Type* value, std::false_type); + + template + PROTOBUF_NOINLINE void AddAllocatedSlowWithCopy( + typename TypeHandler::Type* value, Arena* value_arena, Arena* my_arena); + template + PROTOBUF_NOINLINE void AddAllocatedSlowWithoutCopy( + typename TypeHandler::Type* value); + + template + typename TypeHandler::Type* ReleaseLastInternal(std::true_type); + template + typename TypeHandler::Type* ReleaseLastInternal(std::false_type); + + template + PROTOBUF_NOINLINE void SwapFallback(RepeatedPtrFieldBase* other); + + inline Arena* GetArena() const { return arena_; } + + private: + static constexpr int kInitialSize = 0; + // A few notes on internal representation: + // + // We use an indirected approach, with struct Rep, to keep + // sizeof(RepeatedPtrFieldBase) equivalent to what it was before arena support + // was added, namely, 3 8-byte machine words on x86-64. An instance of Rep is + // allocated only when the repeated field is non-empty, and it is a + // dynamically-sized struct (the header is directly followed by elements[]). + // We place arena_ and current_size_ directly in the object to avoid cache + // misses due to the indirection, because these fields are checked frequently. + // Placing all fields directly in the RepeatedPtrFieldBase instance costs + // significant performance for memory-sensitive workloads. + Arena* arena_; + int current_size_; + int total_size_; + struct Rep { + int allocated_size; + void* elements[1]; + }; + static constexpr size_t kRepHeaderSize = sizeof(Rep) - sizeof(void*); + Rep* rep_; + + template + static inline typename TypeHandler::Type* cast(void* element) { + return reinterpret_cast(element); + } + template + static inline const typename TypeHandler::Type* cast(const void* element) { + return reinterpret_cast(element); + } + + // Non-templated inner function to avoid code duplication. Takes a function + // pointer to the type-specific (templated) inner allocate/merge loop. + void MergeFromInternal(const RepeatedPtrFieldBase& other, + void (RepeatedPtrFieldBase::*inner_loop)(void**, + void**, int, + int)); + + template + void MergeFromInnerLoop(void** our_elems, void** other_elems, int length, + int already_allocated); + + // Internal helper: extend array space if necessary to contain |extend_amount| + // more elements, and return a pointer to the element immediately following + // the old list of elements. This interface factors out common behavior from + // Reserve() and MergeFrom() to reduce code size. |extend_amount| must be > 0. + void** InternalExtend(int extend_amount); + + // The reflection implementation needs to call protected methods directly, + // reinterpreting pointers as being to Message instead of a specific Message + // subclass. + friend class ::PROTOBUF_NAMESPACE_ID::Reflection; + + // ExtensionSet stores repeated message extensions as + // RepeatedPtrField, but non-lite ExtensionSets need to implement + // SpaceUsedLong(), and thus need to call SpaceUsedExcludingSelfLong() + // reinterpreting MessageLite as Message. ExtensionSet also needs to make use + // of AddFromCleared(), which is not part of the public interface. + friend class ExtensionSet; + + // The MapFieldBase implementation needs to call protected methods directly, + // reinterpreting pointers as being to Message instead of a specific Message + // subclass. + friend class MapFieldBase; + + // The table-driven MergePartialFromCodedStream implementation needs to + // operate on RepeatedPtrField. + friend class MergePartialFromCodedStreamHelper; + friend class AccessorHelper; + template + friend struct google::protobuf::WeakRepeatedPtrField; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(RepeatedPtrFieldBase); +}; + +template +class GenericTypeHandler { + public: + typedef GenericType Type; + using Movable = IsMovable; + + static inline GenericType* New(Arena* arena) { + return Arena::CreateMaybeMessage(arena); + } + static inline GenericType* New(Arena* arena, GenericType&& value) { + return Arena::Create(arena, std::move(value)); + } + static inline GenericType* NewFromPrototype(const GenericType* prototype, + Arena* arena = NULL); + static inline void Delete(GenericType* value, Arena* arena) { + if (arena == NULL) { + delete value; + } + } + static inline Arena* GetArena(GenericType* value) { + return Arena::GetArena(value); + } + static inline void* GetMaybeArenaPointer(GenericType* value) { + return Arena::GetArena(value); + } + + static inline void Clear(GenericType* value) { value->Clear(); } + PROTOBUF_NOINLINE + static void Merge(const GenericType& from, GenericType* to); + static inline size_t SpaceUsedLong(const GenericType& value) { + return value.SpaceUsedLong(); + } +}; + +template +GenericType* GenericTypeHandler::NewFromPrototype( + const GenericType* /* prototype */, Arena* arena) { + return New(arena); +} +template +void GenericTypeHandler::Merge(const GenericType& from, + GenericType* to) { + to->MergeFrom(from); +} + +// NewFromPrototype() and Merge() are not defined inline here, as we will need +// to do a virtual function dispatch anyways to go from Message* to call +// New/Merge. +template <> +MessageLite* GenericTypeHandler::NewFromPrototype( + const MessageLite* prototype, Arena* arena); +template <> +inline Arena* GenericTypeHandler::GetArena(MessageLite* value) { + return value->GetArena(); +} +template <> +inline void* GenericTypeHandler::GetMaybeArenaPointer( + MessageLite* value) { + return value->GetMaybeArenaPointer(); +} +template <> +void GenericTypeHandler::Merge(const MessageLite& from, + MessageLite* to); +template <> +inline void GenericTypeHandler::Clear(std::string* value) { + value->clear(); +} +template <> +void GenericTypeHandler::Merge(const std::string& from, + std::string* to); + +// Message specialization bodies defined in message.cc. This split is necessary +// to allow proto2-lite (which includes this header) to be independent of +// Message. +template <> +PROTOBUF_EXPORT Message* GenericTypeHandler::NewFromPrototype( + const Message* prototype, Arena* arena); +template <> +PROTOBUF_EXPORT Arena* GenericTypeHandler::GetArena(Message* value); +template <> +PROTOBUF_EXPORT void* GenericTypeHandler::GetMaybeArenaPointer( + Message* value); + +class StringTypeHandler { + public: + typedef std::string Type; + using Movable = IsMovable; + + static inline std::string* New(Arena* arena) { + return Arena::Create(arena); + } + static inline std::string* New(Arena* arena, std::string&& value) { + return Arena::Create(arena, std::move(value)); + } + static inline std::string* NewFromPrototype(const std::string*, + Arena* arena) { + return New(arena); + } + static inline Arena* GetArena(std::string*) { return NULL; } + static inline void* GetMaybeArenaPointer(std::string* /* value */) { + return NULL; + } + static inline void Delete(std::string* value, Arena* arena) { + if (arena == NULL) { + delete value; + } + } + static inline void Clear(std::string* value) { value->clear(); } + static inline void Merge(const std::string& from, std::string* to) { + *to = from; + } + static size_t SpaceUsedLong(const std::string& value) { + return sizeof(value) + StringSpaceUsedExcludingSelfLong(value); + } +}; + +} // namespace internal + +// RepeatedPtrField is like RepeatedField, but used for repeated strings or +// Messages. +template +class RepeatedPtrField final : private internal::RepeatedPtrFieldBase { + public: + RepeatedPtrField(); + explicit RepeatedPtrField(Arena* arena); + + RepeatedPtrField(const RepeatedPtrField& other); + template + RepeatedPtrField(Iter begin, const Iter& end); + ~RepeatedPtrField(); + + RepeatedPtrField& operator=(const RepeatedPtrField& other); + + RepeatedPtrField(RepeatedPtrField&& other) noexcept; + RepeatedPtrField& operator=(RepeatedPtrField&& other) noexcept; + + bool empty() const; + int size() const; + + const Element& Get(int index) const; + Element* Mutable(int index); + Element* Add(); + void Add(Element&& value); + + const Element& operator[](int index) const { return Get(index); } + Element& operator[](int index) { return *Mutable(index); } + + const Element& at(int index) const; + Element& at(int index); + + // Remove the last element in the array. + // Ownership of the element is retained by the array. + void RemoveLast(); + + // Delete elements with indices in the range [start .. start+num-1]. + // Caution: implementation moves all elements with indices [start+num .. ]. + // Calling this routine inside a loop can cause quadratic behavior. + void DeleteSubrange(int start, int num); + + void Clear(); + void MergeFrom(const RepeatedPtrField& other); + void CopyFrom(const RepeatedPtrField& other); + + // Reserve space to expand the field to at least the given size. This only + // resizes the pointer array; it doesn't allocate any objects. If the + // array is grown, it will always be at least doubled in size. + void Reserve(int new_size); + + int Capacity() const; + + // Gets the underlying array. This pointer is possibly invalidated by + // any add or remove operation. + Element** mutable_data(); + const Element* const* data() const; + + // Swap entire contents with "other". If they are on separate arenas, then + // copies data. + void Swap(RepeatedPtrField* other); + + // Swap entire contents with "other". Caller should guarantee that either both + // fields are on the same arena or both are on the heap. Swapping between + // different arenas with this function is disallowed and is caught via + // GOOGLE_DCHECK. + void UnsafeArenaSwap(RepeatedPtrField* other); + + // Swap two elements. + void SwapElements(int index1, int index2); + + // STL-like iterator support + typedef internal::RepeatedPtrIterator iterator; + typedef internal::RepeatedPtrIterator const_iterator; + typedef Element value_type; + typedef value_type& reference; + typedef const value_type& const_reference; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef int size_type; + typedef ptrdiff_t difference_type; + + iterator begin(); + const_iterator begin() const; + const_iterator cbegin() const; + iterator end(); + const_iterator end() const; + const_iterator cend() const; + + // Reverse iterator support + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; + reverse_iterator rbegin() { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + + // Custom STL-like iterator that iterates over and returns the underlying + // pointers to Element rather than Element itself. + typedef internal::RepeatedPtrOverPtrsIterator + pointer_iterator; + typedef internal::RepeatedPtrOverPtrsIterator + const_pointer_iterator; + pointer_iterator pointer_begin(); + const_pointer_iterator pointer_begin() const; + pointer_iterator pointer_end(); + const_pointer_iterator pointer_end() const; + + // Returns (an estimate of) the number of bytes used by the repeated field, + // excluding sizeof(*this). + size_t SpaceUsedExcludingSelfLong() const; + + int SpaceUsedExcludingSelf() const { + return internal::ToIntSize(SpaceUsedExcludingSelfLong()); + } + + // Advanced memory management -------------------------------------- + // When hardcore memory management becomes necessary -- as it sometimes + // does here at Google -- the following methods may be useful. + + // Add an already-allocated object, passing ownership to the + // RepeatedPtrField. + // + // Note that some special behavior occurs with respect to arenas: + // + // (i) if this field holds submessages, the new submessage will be copied if + // the original is in an arena and this RepeatedPtrField is either in a + // different arena, or on the heap. + // (ii) if this field holds strings, the passed-in string *must* be + // heap-allocated, not arena-allocated. There is no way to dynamically check + // this at runtime, so User Beware. + void AddAllocated(Element* value); + + // Remove the last element and return it, passing ownership to the caller. + // Requires: size() > 0 + // + // If this RepeatedPtrField is on an arena, an object copy is required to pass + // ownership back to the user (for compatible semantics). Use + // UnsafeArenaReleaseLast() if this behavior is undesired. + Element* ReleaseLast(); + + // Add an already-allocated object, skipping arena-ownership checks. The user + // must guarantee that the given object is in the same arena as this + // RepeatedPtrField. + // It is also useful in legacy code that uses temporary ownership to avoid + // copies. Example: + // RepeatedPtrField temp_field; + // temp_field.AddAllocated(new T); + // ... // Do something with temp_field + // temp_field.ExtractSubrange(0, temp_field.size(), nullptr); + // If you put temp_field on the arena this fails, because the ownership + // transfers to the arena at the "AddAllocated" call and is not released + // anymore causing a double delete. UnsafeArenaAddAllocated prevents this. + void UnsafeArenaAddAllocated(Element* value); + + // Remove the last element and return it. Works only when operating on an + // arena. The returned pointer is to the original object in the arena, hence + // has the arena's lifetime. + // Requires: current_size_ > 0 + Element* UnsafeArenaReleaseLast(); + + // Extract elements with indices in the range "[start .. start+num-1]". + // The caller assumes ownership of the extracted elements and is responsible + // for deleting them when they are no longer needed. + // If "elements" is non-NULL, then pointers to the extracted elements + // are stored in "elements[0 .. num-1]" for the convenience of the caller. + // If "elements" is NULL, then the caller must use some other mechanism + // to perform any further operations (like deletion) on these elements. + // Caution: implementation also moves elements with indices [start+num ..]. + // Calling this routine inside a loop can cause quadratic behavior. + // + // Memory copying behavior is identical to ReleaseLast(), described above: if + // this RepeatedPtrField is on an arena, an object copy is performed for each + // returned element, so that all returned element pointers are to + // heap-allocated copies. If this copy is not desired, the user should call + // UnsafeArenaExtractSubrange(). + void ExtractSubrange(int start, int num, Element** elements); + + // Identical to ExtractSubrange() described above, except that when this + // repeated field is on an arena, no object copies are performed. Instead, the + // raw object pointers are returned. Thus, if on an arena, the returned + // objects must not be freed, because they will not be heap-allocated objects. + void UnsafeArenaExtractSubrange(int start, int num, Element** elements); + + // When elements are removed by calls to RemoveLast() or Clear(), they + // are not actually freed. Instead, they are cleared and kept so that + // they can be reused later. This can save lots of CPU time when + // repeatedly reusing a protocol message for similar purposes. + // + // Hardcore programs may choose to manipulate these cleared objects + // to better optimize memory management using the following routines. + + // Get the number of cleared objects that are currently being kept + // around for reuse. + int ClearedCount() const; + // Add an element to the pool of cleared objects, passing ownership to + // the RepeatedPtrField. The element must be cleared prior to calling + // this method. + // + // This method cannot be called when the repeated field is on an arena or when + // |value| is; both cases will trigger a GOOGLE_DCHECK-failure. + void AddCleared(Element* value); + // Remove a single element from the cleared pool and return it, passing + // ownership to the caller. The element is guaranteed to be cleared. + // Requires: ClearedCount() > 0 + // + // + // This method cannot be called when the repeated field is on an arena; doing + // so will trigger a GOOGLE_DCHECK-failure. + Element* ReleaseCleared(); + + // Removes the element referenced by position. + // + // Returns an iterator to the element immediately following the removed + // element. + // + // Invalidates all iterators at or after the removed element, including end(). + iterator erase(const_iterator position); + + // Removes the elements in the range [first, last). + // + // Returns an iterator to the element immediately following the removed range. + // + // Invalidates all iterators at or after the removed range, including end(). + iterator erase(const_iterator first, const_iterator last); + + // Gets the arena on which this RepeatedPtrField stores its elements. + inline Arena* GetArena() const; + + // For internal use only. + // + // This is public due to it being called by generated code. + void InternalSwap(RepeatedPtrField* other) { + internal::RepeatedPtrFieldBase::InternalSwap(other); + } + + private: + // Note: RepeatedPtrField SHOULD NOT be subclassed by users. + class TypeHandler; + + // Implementations for ExtractSubrange(). The copying behavior must be + // included only if the type supports the necessary operations (e.g., + // MergeFrom()), so we must resolve this at compile time. ExtractSubrange() + // uses SFINAE to choose one of the below implementations. + void ExtractSubrangeInternal(int start, int num, Element** elements, + std::true_type); + void ExtractSubrangeInternal(int start, int num, Element** elements, + std::false_type); + + friend class Arena; + + template + friend struct WeakRepeatedPtrField; + + typedef void InternalArenaConstructable_; + +}; + +// implementation ==================================================== + +template +inline RepeatedField::RepeatedField() + : current_size_(0), total_size_(0), arena_or_elements_(nullptr) {} + +template +inline RepeatedField::RepeatedField(Arena* arena) + : current_size_(0), total_size_(0), arena_or_elements_(arena) {} + +template +inline RepeatedField::RepeatedField(const RepeatedField& other) + : current_size_(0), total_size_(0), arena_or_elements_(nullptr) { + if (other.current_size_ != 0) { + Reserve(other.size()); + AddNAlreadyReserved(other.size()); + CopyArray(Mutable(0), &other.Get(0), other.size()); + } +} + +template +template +RepeatedField::RepeatedField(Iter begin, const Iter& end) + : current_size_(0), total_size_(0), arena_or_elements_(nullptr) { + Add(begin, end); +} + +template +RepeatedField::~RepeatedField() { + if (total_size_ > 0) { + InternalDeallocate(rep(), total_size_); + } +} + +template +inline RepeatedField& RepeatedField::operator=( + const RepeatedField& other) { + if (this != &other) CopyFrom(other); + return *this; +} + +template +inline RepeatedField::RepeatedField(RepeatedField&& other) noexcept + : RepeatedField() { + // We don't just call Swap(&other) here because it would perform 3 copies if + // other is on an arena. This field can't be on an arena because arena + // construction always uses the Arena* accepting constructor. + if (other.GetArena()) { + CopyFrom(other); + } else { + InternalSwap(&other); + } +} + +template +inline RepeatedField& RepeatedField::operator=( + RepeatedField&& other) noexcept { + // We don't just call Swap(&other) here because it would perform 3 copies if + // the two fields are on different arenas. + if (this != &other) { + if (this->GetArena() != other.GetArena()) { + CopyFrom(other); + } else { + InternalSwap(&other); + } + } + return *this; +} + +template +inline bool RepeatedField::empty() const { + return current_size_ == 0; +} + +template +inline int RepeatedField::size() const { + return current_size_; +} + +template +inline int RepeatedField::Capacity() const { + return total_size_; +} + +template +inline void RepeatedField::AddAlreadyReserved(const Element& value) { + GOOGLE_DCHECK_LT(current_size_, total_size_); + elements()[current_size_++] = value; +} + +template +inline Element* RepeatedField::AddAlreadyReserved() { + GOOGLE_DCHECK_LT(current_size_, total_size_); + return &elements()[current_size_++]; +} + +template +inline Element* RepeatedField::AddNAlreadyReserved(int n) { + GOOGLE_DCHECK_GE(total_size_ - current_size_, n) + << total_size_ << ", " << current_size_; + // Warning: sometimes people call this when n == 0 and total_size_ == 0. In + // this case the return pointer points to a zero size array (n == 0). Hence + // we can just use unsafe_elements(), because the user cannot dereference the + // pointer anyway. + Element* ret = unsafe_elements() + current_size_; + current_size_ += n; + return ret; +} + +template +inline void RepeatedField::Resize(int new_size, const Element& value) { + GOOGLE_DCHECK_GE(new_size, 0); + if (new_size > current_size_) { + Reserve(new_size); + std::fill(&elements()[current_size_], &elements()[new_size], value); + } + current_size_ = new_size; +} + +template +inline const Element& RepeatedField::Get(int index) const { + GOOGLE_DCHECK_GE(index, 0); + GOOGLE_DCHECK_LT(index, current_size_); + return elements()[index]; +} + +template +inline const Element& RepeatedField::at(int index) const { + GOOGLE_CHECK_GE(index, 0); + GOOGLE_CHECK_LT(index, current_size_); + return elements()[index]; +} + +template +inline Element& RepeatedField::at(int index) { + GOOGLE_CHECK_GE(index, 0); + GOOGLE_CHECK_LT(index, current_size_); + return elements()[index]; +} + +template +inline Element* RepeatedField::Mutable(int index) { + GOOGLE_DCHECK_GE(index, 0); + GOOGLE_DCHECK_LT(index, current_size_); + return &elements()[index]; +} + +template +inline void RepeatedField::Set(int index, const Element& value) { + GOOGLE_DCHECK_GE(index, 0); + GOOGLE_DCHECK_LT(index, current_size_); + elements()[index] = value; +} + +template +inline void RepeatedField::Add(const Element& value) { + uint32 size = current_size_; + if (static_cast(size) == total_size_) { + // value could reference an element of the array. Reserving new space will + // invalidate the reference. So we must make a copy first. + auto tmp = value; + Reserve(total_size_ + 1); + elements()[size] = std::move(tmp); + } else { + elements()[size] = value; + } + current_size_ = size + 1; +} + +template +inline Element* RepeatedField::Add() { + uint32 size = current_size_; + if (static_cast(size) == total_size_) Reserve(total_size_ + 1); + auto ptr = &elements()[size]; + current_size_ = size + 1; + return ptr; +} + +template +template +inline void RepeatedField::Add(Iter begin, Iter end) { + int reserve = internal::CalculateReserve(begin, end); + if (reserve != -1) { + if (reserve == 0) { + return; + } + + Reserve(reserve + size()); + // TODO(ckennelly): The compiler loses track of the buffer freshly + // allocated by Reserve() by the time we call elements, so it cannot + // guarantee that elements does not alias [begin(), end()). + // + // If restrict is available, annotating the pointer obtained from elements() + // causes this to lower to memcpy instead of memmove. + std::copy(begin, end, elements() + size()); + current_size_ = reserve + size(); + } else { + FastAdder fast_adder(this); + for (; begin != end; ++begin) fast_adder.Add(*begin); + } +} + +template +inline void RepeatedField::RemoveLast() { + GOOGLE_DCHECK_GT(current_size_, 0); + current_size_--; +} + +template +void RepeatedField::ExtractSubrange(int start, int num, + Element* elements) { + GOOGLE_DCHECK_GE(start, 0); + GOOGLE_DCHECK_GE(num, 0); + GOOGLE_DCHECK_LE(start + num, this->current_size_); + + // Save the values of the removed elements if requested. + if (elements != NULL) { + for (int i = 0; i < num; ++i) elements[i] = this->Get(i + start); + } + + // Slide remaining elements down to fill the gap. + if (num > 0) { + for (int i = start + num; i < this->current_size_; ++i) + this->Set(i - num, this->Get(i)); + this->Truncate(this->current_size_ - num); + } +} + +template +inline void RepeatedField::Clear() { + current_size_ = 0; +} + +template +inline void RepeatedField::MergeFrom(const RepeatedField& other) { + GOOGLE_DCHECK_NE(&other, this); + if (other.current_size_ != 0) { + int existing_size = size(); + Reserve(existing_size + other.size()); + AddNAlreadyReserved(other.size()); + CopyArray(Mutable(existing_size), &other.Get(0), other.size()); + } +} + +template +inline void RepeatedField::CopyFrom(const RepeatedField& other) { + if (&other == this) return; + Clear(); + MergeFrom(other); +} + +template +inline typename RepeatedField::iterator RepeatedField::erase( + const_iterator position) { + return erase(position, position + 1); +} + +template +inline typename RepeatedField::iterator RepeatedField::erase( + const_iterator first, const_iterator last) { + size_type first_offset = first - cbegin(); + if (first != last) { + Truncate(std::copy(last, cend(), begin() + first_offset) - cbegin()); + } + return begin() + first_offset; +} + +template +inline Element* RepeatedField::mutable_data() { + return unsafe_elements(); +} + +template +inline const Element* RepeatedField::data() const { + return unsafe_elements(); +} + +template +inline void RepeatedField::InternalSwap(RepeatedField* other) { + GOOGLE_DCHECK(this != other); + GOOGLE_DCHECK(GetArena() == other->GetArena()); + + // Swap all fields at once. + static_assert(std::is_standard_layout>::value, + "offsetof() requires standard layout before c++17"); + internal::memswaparena_or_elements_) - + offsetof(RepeatedField, current_size_)>( + reinterpret_cast(this) + offsetof(RepeatedField, current_size_), + reinterpret_cast(other) + offsetof(RepeatedField, current_size_)); +} + +template +void RepeatedField::Swap(RepeatedField* other) { + if (this == other) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + RepeatedField temp(other->GetArena()); + temp.MergeFrom(*this); + CopyFrom(*other); + other->UnsafeArenaSwap(&temp); + } +} + +template +void RepeatedField::UnsafeArenaSwap(RepeatedField* other) { + if (this == other) return; + InternalSwap(other); +} + +template +void RepeatedField::SwapElements(int index1, int index2) { + using std::swap; // enable ADL with fallback + swap(elements()[index1], elements()[index2]); +} + +template +inline typename RepeatedField::iterator +RepeatedField::begin() { + return unsafe_elements(); +} +template +inline typename RepeatedField::const_iterator +RepeatedField::begin() const { + return unsafe_elements(); +} +template +inline typename RepeatedField::const_iterator +RepeatedField::cbegin() const { + return unsafe_elements(); +} +template +inline typename RepeatedField::iterator RepeatedField::end() { + return unsafe_elements() + current_size_; +} +template +inline typename RepeatedField::const_iterator +RepeatedField::end() const { + return unsafe_elements() + current_size_; +} +template +inline typename RepeatedField::const_iterator +RepeatedField::cend() const { + return unsafe_elements() + current_size_; +} + +template +inline size_t RepeatedField::SpaceUsedExcludingSelfLong() const { + return total_size_ > 0 ? (total_size_ * sizeof(Element) + kRepHeaderSize) : 0; +} + +namespace internal { +// Returns the new size for a reserved field based on its 'total_size' and the +// requested 'new_size'. The result is clamped to the closed interval: +// [internal::kMinRepeatedFieldAllocationSize, +// std::numeric_limits::max()] +// Requires: +// new_size > total_size && +// (total_size == 0 || +// total_size >= kRepeatedFieldLowerClampLimit) +inline int CalculateReserveSize(int total_size, int new_size) { + if (new_size < kRepeatedFieldLowerClampLimit) { + // Clamp to smallest allowed size. + return kRepeatedFieldLowerClampLimit; + } + if (total_size < kRepeatedFieldUpperClampLimit) { + return std::max(total_size * 2, new_size); + } else { + // Clamp to largest allowed size. + GOOGLE_DCHECK_GT(new_size, kRepeatedFieldUpperClampLimit); + return std::numeric_limits::max(); + } +} +} // namespace internal + +// Avoid inlining of Reserve(): new, copy, and delete[] lead to a significant +// amount of code bloat. +template +void RepeatedField::Reserve(int new_size) { + if (total_size_ >= new_size) return; + Rep* old_rep = total_size_ > 0 ? rep() : NULL; + Rep* new_rep; + Arena* arena = GetArena(); + new_size = internal::CalculateReserveSize(total_size_, new_size); + GOOGLE_DCHECK_LE( + static_cast(new_size), + (std::numeric_limits::max() - kRepHeaderSize) / sizeof(Element)) + << "Requested size is too large to fit into size_t."; + size_t bytes = + kRepHeaderSize + sizeof(Element) * static_cast(new_size); + if (arena == NULL) { + new_rep = static_cast(::operator new(bytes)); + } else { + new_rep = reinterpret_cast(Arena::CreateArray(arena, bytes)); + } + new_rep->arena = arena; + int old_total_size = total_size_; + // Already known: new_size >= internal::kMinRepeatedFieldAllocationSize + // Maintain invariant: + // total_size_ == 0 || + // total_size_ >= internal::kMinRepeatedFieldAllocationSize + total_size_ = new_size; + arena_or_elements_ = new_rep->elements; + // Invoke placement-new on newly allocated elements. We shouldn't have to do + // this, since Element is supposed to be POD, but a previous version of this + // code allocated storage with "new Element[size]" and some code uses + // RepeatedField with non-POD types, relying on constructor invocation. If + // Element has a trivial constructor (e.g., int32), gcc (tested with -O2) + // completely removes this loop because the loop body is empty, so this has no + // effect unless its side-effects are required for correctness. + // Note that we do this before MoveArray() below because Element's copy + // assignment implementation will want an initialized instance first. + Element* e = &elements()[0]; + Element* limit = e + total_size_; + for (; e < limit; e++) { + new (e) Element; + } + if (current_size_ > 0) { + MoveArray(&elements()[0], old_rep->elements, current_size_); + } + + // Likewise, we need to invoke destructors on the old array. + InternalDeallocate(old_rep, old_total_size); + +} + +template +inline void RepeatedField::Truncate(int new_size) { + GOOGLE_DCHECK_LE(new_size, current_size_); + if (current_size_ > 0) { + current_size_ = new_size; + } +} + +template +inline void RepeatedField::MoveArray(Element* to, Element* from, + int array_size) { + CopyArray(to, from, array_size); +} + +template +inline void RepeatedField::CopyArray(Element* to, const Element* from, + int array_size) { + internal::ElementCopier()(to, from, array_size); +} + +namespace internal { + +template +void ElementCopier::operator()(Element* to, + const Element* from, + int array_size) { + std::copy(from, from + array_size, to); +} + +template +struct ElementCopier { + void operator()(Element* to, const Element* from, int array_size) { + memcpy(to, from, static_cast(array_size) * sizeof(Element)); + } +}; + +} // namespace internal + + +// ------------------------------------------------------------------- + +namespace internal { + +inline RepeatedPtrFieldBase::RepeatedPtrFieldBase() + : arena_(NULL), current_size_(0), total_size_(0), rep_(NULL) {} + +inline RepeatedPtrFieldBase::RepeatedPtrFieldBase(Arena* arena) + : arena_(arena), current_size_(0), total_size_(0), rep_(NULL) {} + +template +void RepeatedPtrFieldBase::Destroy() { + if (rep_ != NULL && arena_ == NULL) { + int n = rep_->allocated_size; + void* const* elements = rep_->elements; + for (int i = 0; i < n; i++) { + TypeHandler::Delete(cast(elements[i]), NULL); + } +#if defined(__GXX_DELETE_WITH_SIZE__) || defined(__cpp_sized_deallocation) + const size_t size = total_size_ * sizeof(elements[0]) + kRepHeaderSize; + ::operator delete(static_cast(rep_), size); +#else + ::operator delete(static_cast(rep_)); +#endif + } + rep_ = NULL; +} + +template +inline void RepeatedPtrFieldBase::Swap(RepeatedPtrFieldBase* other) { + if (other->GetArena() == GetArena()) { + InternalSwap(other); + } else { + SwapFallback(other); + } +} + +template +void RepeatedPtrFieldBase::SwapFallback(RepeatedPtrFieldBase* other) { + GOOGLE_DCHECK(other->GetArena() != GetArena()); + + // Copy semantics in this case. We try to improve efficiency by placing the + // temporary on |other|'s arena so that messages are copied twice rather than + // three times. + RepeatedPtrFieldBase temp(other->GetArena()); + temp.MergeFrom(*this); + this->Clear(); + this->MergeFrom(*other); + other->InternalSwap(&temp); + temp.Destroy(); // Frees rep_ if `other` had no arena. +} + +inline bool RepeatedPtrFieldBase::empty() const { return current_size_ == 0; } + +inline int RepeatedPtrFieldBase::size() const { return current_size_; } + +template +inline const typename TypeHandler::Type& RepeatedPtrFieldBase::Get( + int index) const { + GOOGLE_DCHECK_GE(index, 0); + GOOGLE_DCHECK_LT(index, current_size_); + return *cast(rep_->elements[index]); +} + +template +inline const typename TypeHandler::Type& RepeatedPtrFieldBase::at( + int index) const { + GOOGLE_CHECK_GE(index, 0); + GOOGLE_CHECK_LT(index, current_size_); + return *cast(rep_->elements[index]); +} + +template +inline typename TypeHandler::Type& RepeatedPtrFieldBase::at(int index) { + GOOGLE_CHECK_GE(index, 0); + GOOGLE_CHECK_LT(index, current_size_); + return *cast(rep_->elements[index]); +} + +template +inline typename TypeHandler::Type* RepeatedPtrFieldBase::Mutable(int index) { + GOOGLE_DCHECK_GE(index, 0); + GOOGLE_DCHECK_LT(index, current_size_); + return cast(rep_->elements[index]); +} + +template +inline void RepeatedPtrFieldBase::Delete(int index) { + GOOGLE_DCHECK_GE(index, 0); + GOOGLE_DCHECK_LT(index, current_size_); + TypeHandler::Delete(cast(rep_->elements[index]), arena_); +} + +template +inline typename TypeHandler::Type* RepeatedPtrFieldBase::Add( + typename TypeHandler::Type* prototype) { + if (rep_ != NULL && current_size_ < rep_->allocated_size) { + return cast(rep_->elements[current_size_++]); + } + if (!rep_ || rep_->allocated_size == total_size_) { + Reserve(total_size_ + 1); + } + ++rep_->allocated_size; + typename TypeHandler::Type* result = + TypeHandler::NewFromPrototype(prototype, arena_); + rep_->elements[current_size_++] = result; + return result; +} + +template ::type*> +inline void RepeatedPtrFieldBase::Add(typename TypeHandler::Type&& value) { + if (rep_ != NULL && current_size_ < rep_->allocated_size) { + *cast(rep_->elements[current_size_++]) = std::move(value); + return; + } + if (!rep_ || rep_->allocated_size == total_size_) { + Reserve(total_size_ + 1); + } + ++rep_->allocated_size; + typename TypeHandler::Type* result = + TypeHandler::New(arena_, std::move(value)); + rep_->elements[current_size_++] = result; +} + +template +inline void RepeatedPtrFieldBase::RemoveLast() { + GOOGLE_DCHECK_GT(current_size_, 0); + TypeHandler::Clear(cast(rep_->elements[--current_size_])); +} + +template +void RepeatedPtrFieldBase::Clear() { + const int n = current_size_; + GOOGLE_DCHECK_GE(n, 0); + if (n > 0) { + void* const* elements = rep_->elements; + int i = 0; + do { + TypeHandler::Clear(cast(elements[i++])); + } while (i < n); + current_size_ = 0; + } +} + +// To avoid unnecessary code duplication and reduce binary size, we use a +// layered approach to implementing MergeFrom(). The toplevel method is +// templated, so we get a small thunk per concrete message type in the binary. +// This calls a shared implementation with most of the logic, passing a function +// pointer to another type-specific piece of code that calls the object-allocate +// and merge handlers. +template +inline void RepeatedPtrFieldBase::MergeFrom(const RepeatedPtrFieldBase& other) { + GOOGLE_DCHECK_NE(&other, this); + if (other.current_size_ == 0) return; + MergeFromInternal(other, + &RepeatedPtrFieldBase::MergeFromInnerLoop); +} + +inline void RepeatedPtrFieldBase::MergeFromInternal( + const RepeatedPtrFieldBase& other, + void (RepeatedPtrFieldBase::*inner_loop)(void**, void**, int, int)) { + // Note: wrapper has already guaranteed that other.rep_ != NULL here. + int other_size = other.current_size_; + void** other_elements = other.rep_->elements; + void** new_elements = InternalExtend(other_size); + int allocated_elems = rep_->allocated_size - current_size_; + (this->*inner_loop)(new_elements, other_elements, other_size, + allocated_elems); + current_size_ += other_size; + if (rep_->allocated_size < current_size_) { + rep_->allocated_size = current_size_; + } +} + +// Merges other_elems to our_elems. +template +void RepeatedPtrFieldBase::MergeFromInnerLoop(void** our_elems, + void** other_elems, int length, + int already_allocated) { + // Split into two loops, over ranges [0, allocated) and [allocated, length), + // to avoid a branch within the loop. + for (int i = 0; i < already_allocated && i < length; i++) { + // Already allocated: use existing element. + typename TypeHandler::Type* other_elem = + reinterpret_cast(other_elems[i]); + typename TypeHandler::Type* new_elem = + reinterpret_cast(our_elems[i]); + TypeHandler::Merge(*other_elem, new_elem); + } + Arena* arena = GetArena(); + for (int i = already_allocated; i < length; i++) { + // Not allocated: alloc a new element first, then merge it. + typename TypeHandler::Type* other_elem = + reinterpret_cast(other_elems[i]); + typename TypeHandler::Type* new_elem = + TypeHandler::NewFromPrototype(other_elem, arena); + TypeHandler::Merge(*other_elem, new_elem); + our_elems[i] = new_elem; + } +} + +template +inline void RepeatedPtrFieldBase::CopyFrom(const RepeatedPtrFieldBase& other) { + if (&other == this) return; + RepeatedPtrFieldBase::Clear(); + RepeatedPtrFieldBase::MergeFrom(other); +} + +inline int RepeatedPtrFieldBase::Capacity() const { return total_size_; } + +inline void* const* RepeatedPtrFieldBase::raw_data() const { + return rep_ ? rep_->elements : NULL; +} + +inline void** RepeatedPtrFieldBase::raw_mutable_data() const { + return rep_ ? const_cast(rep_->elements) : NULL; +} + +template +inline typename TypeHandler::Type** RepeatedPtrFieldBase::mutable_data() { + // TODO(kenton): Breaks C++ aliasing rules. We should probably remove this + // method entirely. + return reinterpret_cast(raw_mutable_data()); +} + +template +inline const typename TypeHandler::Type* const* RepeatedPtrFieldBase::data() + const { + // TODO(kenton): Breaks C++ aliasing rules. We should probably remove this + // method entirely. + return reinterpret_cast(raw_data()); +} + +inline void RepeatedPtrFieldBase::SwapElements(int index1, int index2) { + using std::swap; // enable ADL with fallback + swap(rep_->elements[index1], rep_->elements[index2]); +} + +template +inline size_t RepeatedPtrFieldBase::SpaceUsedExcludingSelfLong() const { + size_t allocated_bytes = static_cast(total_size_) * sizeof(void*); + if (rep_ != NULL) { + for (int i = 0; i < rep_->allocated_size; ++i) { + allocated_bytes += + TypeHandler::SpaceUsedLong(*cast(rep_->elements[i])); + } + allocated_bytes += kRepHeaderSize; + } + return allocated_bytes; +} + +template +inline typename TypeHandler::Type* RepeatedPtrFieldBase::AddFromCleared() { + if (rep_ != NULL && current_size_ < rep_->allocated_size) { + return cast(rep_->elements[current_size_++]); + } else { + return NULL; + } +} + +// AddAllocated version that implements arena-safe copying behavior. +template +void RepeatedPtrFieldBase::AddAllocatedInternal( + typename TypeHandler::Type* value, std::true_type) { + Arena* element_arena = + reinterpret_cast(TypeHandler::GetMaybeArenaPointer(value)); + Arena* arena = GetArena(); + if (arena == element_arena && rep_ && rep_->allocated_size < total_size_) { + // Fast path: underlying arena representation (tagged pointer) is equal to + // our arena pointer, and we can add to array without resizing it (at least + // one slot that is not allocated). + void** elems = rep_->elements; + if (current_size_ < rep_->allocated_size) { + // Make space at [current] by moving first allocated element to end of + // allocated list. + elems[rep_->allocated_size] = elems[current_size_]; + } + elems[current_size_] = value; + current_size_ = current_size_ + 1; + rep_->allocated_size = rep_->allocated_size + 1; + } else { + AddAllocatedSlowWithCopy(value, TypeHandler::GetArena(value), + arena); + } +} + +// Slowpath handles all cases, copying if necessary. +template +void RepeatedPtrFieldBase::AddAllocatedSlowWithCopy( + // Pass value_arena and my_arena to avoid duplicate virtual call (value) or + // load (mine). + typename TypeHandler::Type* value, Arena* value_arena, Arena* my_arena) { + // Ensure that either the value is in the same arena, or if not, we do the + // appropriate thing: Own() it (if it's on heap and we're in an arena) or copy + // it to our arena/heap (otherwise). + if (my_arena != NULL && value_arena == NULL) { + my_arena->Own(value); + } else if (my_arena != value_arena) { + typename TypeHandler::Type* new_value = + TypeHandler::NewFromPrototype(value, my_arena); + TypeHandler::Merge(*value, new_value); + TypeHandler::Delete(value, value_arena); + value = new_value; + } + + UnsafeArenaAddAllocated(value); +} + +// AddAllocated version that does not implement arena-safe copying behavior. +template +void RepeatedPtrFieldBase::AddAllocatedInternal( + typename TypeHandler::Type* value, std::false_type) { + if (rep_ && rep_->allocated_size < total_size_) { + // Fast path: underlying arena representation (tagged pointer) is equal to + // our arena pointer, and we can add to array without resizing it (at least + // one slot that is not allocated). + void** elems = rep_->elements; + if (current_size_ < rep_->allocated_size) { + // Make space at [current] by moving first allocated element to end of + // allocated list. + elems[rep_->allocated_size] = elems[current_size_]; + } + elems[current_size_] = value; + current_size_ = current_size_ + 1; + ++rep_->allocated_size; + } else { + UnsafeArenaAddAllocated(value); + } +} + +template +void RepeatedPtrFieldBase::UnsafeArenaAddAllocated( + typename TypeHandler::Type* value) { + // Make room for the new pointer. + if (!rep_ || current_size_ == total_size_) { + // The array is completely full with no cleared objects, so grow it. + Reserve(total_size_ + 1); + ++rep_->allocated_size; + } else if (rep_->allocated_size == total_size_) { + // There is no more space in the pointer array because it contains some + // cleared objects awaiting reuse. We don't want to grow the array in this + // case because otherwise a loop calling AddAllocated() followed by Clear() + // would leak memory. + TypeHandler::Delete(cast(rep_->elements[current_size_]), + arena_); + } else if (current_size_ < rep_->allocated_size) { + // We have some cleared objects. We don't care about their order, so we + // can just move the first one to the end to make space. + rep_->elements[rep_->allocated_size] = rep_->elements[current_size_]; + ++rep_->allocated_size; + } else { + // There are no cleared objects. + ++rep_->allocated_size; + } + + rep_->elements[current_size_++] = value; +} + +// ReleaseLast() for types that implement merge/copy behavior. +template +inline typename TypeHandler::Type* RepeatedPtrFieldBase::ReleaseLastInternal( + std::true_type) { + // First, release an element. + typename TypeHandler::Type* result = UnsafeArenaReleaseLast(); + // Now perform a copy if we're on an arena. + Arena* arena = GetArena(); + if (arena == NULL) { + return result; + } else { + typename TypeHandler::Type* new_result = + TypeHandler::NewFromPrototype(result, NULL); + TypeHandler::Merge(*result, new_result); + return new_result; + } +} + +// ReleaseLast() for types that *do not* implement merge/copy behavior -- this +// is the same as UnsafeArenaReleaseLast(). Note that we GOOGLE_DCHECK-fail if we're on +// an arena, since the user really should implement the copy operation in this +// case. +template +inline typename TypeHandler::Type* RepeatedPtrFieldBase::ReleaseLastInternal( + std::false_type) { + GOOGLE_DCHECK(GetArena() == NULL) + << "ReleaseLast() called on a RepeatedPtrField that is on an arena, " + << "with a type that does not implement MergeFrom. This is unsafe; " + << "please implement MergeFrom for your type."; + return UnsafeArenaReleaseLast(); +} + +template +inline typename TypeHandler::Type* +RepeatedPtrFieldBase::UnsafeArenaReleaseLast() { + GOOGLE_DCHECK_GT(current_size_, 0); + typename TypeHandler::Type* result = + cast(rep_->elements[--current_size_]); + --rep_->allocated_size; + if (current_size_ < rep_->allocated_size) { + // There are cleared elements on the end; replace the removed element + // with the last allocated element. + rep_->elements[current_size_] = rep_->elements[rep_->allocated_size]; + } + return result; +} + +inline int RepeatedPtrFieldBase::ClearedCount() const { + return rep_ ? (rep_->allocated_size - current_size_) : 0; +} + +template +inline void RepeatedPtrFieldBase::AddCleared( + typename TypeHandler::Type* value) { + GOOGLE_DCHECK(GetArena() == NULL) + << "AddCleared() can only be used on a RepeatedPtrField not on an arena."; + GOOGLE_DCHECK(TypeHandler::GetArena(value) == NULL) + << "AddCleared() can only accept values not on an arena."; + if (!rep_ || rep_->allocated_size == total_size_) { + Reserve(total_size_ + 1); + } + rep_->elements[rep_->allocated_size++] = value; +} + +template +inline typename TypeHandler::Type* RepeatedPtrFieldBase::ReleaseCleared() { + GOOGLE_DCHECK(GetArena() == NULL) + << "ReleaseCleared() can only be used on a RepeatedPtrField not on " + << "an arena."; + GOOGLE_DCHECK(GetArena() == NULL); + GOOGLE_DCHECK(rep_ != NULL); + GOOGLE_DCHECK_GT(rep_->allocated_size, current_size_); + return cast(rep_->elements[--rep_->allocated_size]); +} + +} // namespace internal + +// ------------------------------------------------------------------- + +template +class RepeatedPtrField::TypeHandler + : public internal::GenericTypeHandler {}; + +template <> +class RepeatedPtrField::TypeHandler + : public internal::StringTypeHandler {}; + +template +inline RepeatedPtrField::RepeatedPtrField() : RepeatedPtrFieldBase() {} + +template +inline RepeatedPtrField::RepeatedPtrField(Arena* arena) + : RepeatedPtrFieldBase(arena) {} + +template +inline RepeatedPtrField::RepeatedPtrField( + const RepeatedPtrField& other) + : RepeatedPtrFieldBase() { + MergeFrom(other); +} + +template +template +inline RepeatedPtrField::RepeatedPtrField(Iter begin, + const Iter& end) { + int reserve = internal::CalculateReserve(begin, end); + if (reserve != -1) { + Reserve(reserve); + } + for (; begin != end; ++begin) { + *Add() = *begin; + } +} + +template +RepeatedPtrField::~RepeatedPtrField() { + Destroy(); +} + +template +inline RepeatedPtrField& RepeatedPtrField::operator=( + const RepeatedPtrField& other) { + if (this != &other) CopyFrom(other); + return *this; +} + +template +inline RepeatedPtrField::RepeatedPtrField( + RepeatedPtrField&& other) noexcept + : RepeatedPtrField() { + // We don't just call Swap(&other) here because it would perform 3 copies if + // other is on an arena. This field can't be on an arena because arena + // construction always uses the Arena* accepting constructor. + if (other.GetArena()) { + CopyFrom(other); + } else { + InternalSwap(&other); + } +} + +template +inline RepeatedPtrField& RepeatedPtrField::operator=( + RepeatedPtrField&& other) noexcept { + // We don't just call Swap(&other) here because it would perform 3 copies if + // the two fields are on different arenas. + if (this != &other) { + if (this->GetArena() != other.GetArena()) { + CopyFrom(other); + } else { + InternalSwap(&other); + } + } + return *this; +} + +template +inline bool RepeatedPtrField::empty() const { + return RepeatedPtrFieldBase::empty(); +} + +template +inline int RepeatedPtrField::size() const { + return RepeatedPtrFieldBase::size(); +} + +template +inline const Element& RepeatedPtrField::Get(int index) const { + return RepeatedPtrFieldBase::Get(index); +} + +template +inline const Element& RepeatedPtrField::at(int index) const { + return RepeatedPtrFieldBase::at(index); +} + +template +inline Element& RepeatedPtrField::at(int index) { + return RepeatedPtrFieldBase::at(index); +} + + +template +inline Element* RepeatedPtrField::Mutable(int index) { + return RepeatedPtrFieldBase::Mutable(index); +} + +template +inline Element* RepeatedPtrField::Add() { + return RepeatedPtrFieldBase::Add(); +} + +template +inline void RepeatedPtrField::Add(Element&& value) { + RepeatedPtrFieldBase::Add(std::move(value)); +} + +template +inline void RepeatedPtrField::RemoveLast() { + RepeatedPtrFieldBase::RemoveLast(); +} + +template +inline void RepeatedPtrField::DeleteSubrange(int start, int num) { + GOOGLE_DCHECK_GE(start, 0); + GOOGLE_DCHECK_GE(num, 0); + GOOGLE_DCHECK_LE(start + num, size()); + for (int i = 0; i < num; ++i) { + RepeatedPtrFieldBase::Delete(start + i); + } + ExtractSubrange(start, num, NULL); +} + +template +inline void RepeatedPtrField::ExtractSubrange(int start, int num, + Element** elements) { + typename internal::TypeImplementsMergeBehavior< + typename TypeHandler::Type>::type t; + ExtractSubrangeInternal(start, num, elements, t); +} + +// ExtractSubrange() implementation for types that implement merge/copy +// behavior. +template +inline void RepeatedPtrField::ExtractSubrangeInternal( + int start, int num, Element** elements, std::true_type) { + GOOGLE_DCHECK_GE(start, 0); + GOOGLE_DCHECK_GE(num, 0); + GOOGLE_DCHECK_LE(start + num, size()); + + if (num > 0) { + // Save the values of the removed elements if requested. + if (elements != NULL) { + if (GetArena() != NULL) { + // If we're on an arena, we perform a copy for each element so that the + // returned elements are heap-allocated. + for (int i = 0; i < num; ++i) { + Element* element = + RepeatedPtrFieldBase::Mutable(i + start); + typename TypeHandler::Type* new_value = + TypeHandler::NewFromPrototype(element, NULL); + TypeHandler::Merge(*element, new_value); + elements[i] = new_value; + } + } else { + for (int i = 0; i < num; ++i) { + elements[i] = RepeatedPtrFieldBase::Mutable(i + start); + } + } + } + CloseGap(start, num); + } +} + +// ExtractSubrange() implementation for types that do not implement merge/copy +// behavior. +template +inline void RepeatedPtrField::ExtractSubrangeInternal( + int start, int num, Element** elements, std::false_type) { + // This case is identical to UnsafeArenaExtractSubrange(). However, since + // ExtractSubrange() must return heap-allocated objects by contract, and we + // cannot fulfill this contract if we are an on arena, we must GOOGLE_DCHECK() that + // we are not on an arena. + GOOGLE_DCHECK(GetArena() == NULL) + << "ExtractSubrange() when arena is non-NULL is only supported when " + << "the Element type supplies a MergeFrom() operation to make copies."; + UnsafeArenaExtractSubrange(start, num, elements); +} + +template +inline void RepeatedPtrField::UnsafeArenaExtractSubrange( + int start, int num, Element** elements) { + GOOGLE_DCHECK_GE(start, 0); + GOOGLE_DCHECK_GE(num, 0); + GOOGLE_DCHECK_LE(start + num, size()); + + if (num > 0) { + // Save the values of the removed elements if requested. + if (elements != NULL) { + for (int i = 0; i < num; ++i) { + elements[i] = RepeatedPtrFieldBase::Mutable(i + start); + } + } + CloseGap(start, num); + } +} + +template +inline void RepeatedPtrField::Clear() { + RepeatedPtrFieldBase::Clear(); +} + +template +inline void RepeatedPtrField::MergeFrom( + const RepeatedPtrField& other) { + RepeatedPtrFieldBase::MergeFrom(other); +} + +template +inline void RepeatedPtrField::CopyFrom(const RepeatedPtrField& other) { + RepeatedPtrFieldBase::CopyFrom(other); +} + +template +inline typename RepeatedPtrField::iterator +RepeatedPtrField::erase(const_iterator position) { + return erase(position, position + 1); +} + +template +inline typename RepeatedPtrField::iterator +RepeatedPtrField::erase(const_iterator first, const_iterator last) { + size_type pos_offset = std::distance(cbegin(), first); + size_type last_offset = std::distance(cbegin(), last); + DeleteSubrange(pos_offset, last_offset - pos_offset); + return begin() + pos_offset; +} + +template +inline Element** RepeatedPtrField::mutable_data() { + return RepeatedPtrFieldBase::mutable_data(); +} + +template +inline const Element* const* RepeatedPtrField::data() const { + return RepeatedPtrFieldBase::data(); +} + +template +inline void RepeatedPtrField::Swap(RepeatedPtrField* other) { + if (this == other) return; + RepeatedPtrFieldBase::Swap(other); +} + +template +inline void RepeatedPtrField::UnsafeArenaSwap( + RepeatedPtrField* other) { + if (this == other) return; + RepeatedPtrFieldBase::InternalSwap(other); +} + +template +inline void RepeatedPtrField::SwapElements(int index1, int index2) { + RepeatedPtrFieldBase::SwapElements(index1, index2); +} + +template +inline Arena* RepeatedPtrField::GetArena() const { + return RepeatedPtrFieldBase::GetArena(); +} + +template +inline size_t RepeatedPtrField::SpaceUsedExcludingSelfLong() const { + return RepeatedPtrFieldBase::SpaceUsedExcludingSelfLong(); +} + +template +inline void RepeatedPtrField::AddAllocated(Element* value) { + RepeatedPtrFieldBase::AddAllocated(value); +} + +template +inline void RepeatedPtrField::UnsafeArenaAddAllocated(Element* value) { + RepeatedPtrFieldBase::UnsafeArenaAddAllocated(value); +} + +template +inline Element* RepeatedPtrField::ReleaseLast() { + return RepeatedPtrFieldBase::ReleaseLast(); +} + +template +inline Element* RepeatedPtrField::UnsafeArenaReleaseLast() { + return RepeatedPtrFieldBase::UnsafeArenaReleaseLast(); +} + +template +inline int RepeatedPtrField::ClearedCount() const { + return RepeatedPtrFieldBase::ClearedCount(); +} + +template +inline void RepeatedPtrField::AddCleared(Element* value) { + return RepeatedPtrFieldBase::AddCleared(value); +} + +template +inline Element* RepeatedPtrField::ReleaseCleared() { + return RepeatedPtrFieldBase::ReleaseCleared(); +} + +template +inline void RepeatedPtrField::Reserve(int new_size) { + return RepeatedPtrFieldBase::Reserve(new_size); +} + +template +inline int RepeatedPtrField::Capacity() const { + return RepeatedPtrFieldBase::Capacity(); +} + +// ------------------------------------------------------------------- + +namespace internal { + +// STL-like iterator implementation for RepeatedPtrField. You should not +// refer to this class directly; use RepeatedPtrField::iterator instead. +// +// The iterator for RepeatedPtrField, RepeatedPtrIterator, is +// very similar to iterator_ptr in util/gtl/iterator_adaptors.h, +// but adds random-access operators and is modified to wrap a void** base +// iterator (since RepeatedPtrField stores its array as a void* array and +// casting void** to T** would violate C++ aliasing rules). +// +// This code based on net/proto/proto-array-internal.h by Jeffrey Yasskin +// (jyasskin@google.com). +template +class RepeatedPtrIterator { + public: + using iterator = RepeatedPtrIterator; + using iterator_category = std::random_access_iterator_tag; + using value_type = typename std::remove_const::type; + using difference_type = std::ptrdiff_t; + using pointer = Element*; + using reference = Element&; + + RepeatedPtrIterator() : it_(NULL) {} + explicit RepeatedPtrIterator(void* const* it) : it_(it) {} + + // Allow "upcasting" from RepeatedPtrIterator to + // RepeatedPtrIterator. + template + RepeatedPtrIterator(const RepeatedPtrIterator& other) + : it_(other.it_) { + // Force a compiler error if the other type is not convertible to ours. + if (false) { + implicit_cast(static_cast(nullptr)); + } + } + + // dereferenceable + reference operator*() const { return *reinterpret_cast(*it_); } + pointer operator->() const { return &(operator*()); } + + // {inc,dec}rementable + iterator& operator++() { + ++it_; + return *this; + } + iterator operator++(int) { return iterator(it_++); } + iterator& operator--() { + --it_; + return *this; + } + iterator operator--(int) { return iterator(it_--); } + + // equality_comparable + bool operator==(const iterator& x) const { return it_ == x.it_; } + bool operator!=(const iterator& x) const { return it_ != x.it_; } + + // less_than_comparable + bool operator<(const iterator& x) const { return it_ < x.it_; } + bool operator<=(const iterator& x) const { return it_ <= x.it_; } + bool operator>(const iterator& x) const { return it_ > x.it_; } + bool operator>=(const iterator& x) const { return it_ >= x.it_; } + + // addable, subtractable + iterator& operator+=(difference_type d) { + it_ += d; + return *this; + } + friend iterator operator+(iterator it, const difference_type d) { + it += d; + return it; + } + friend iterator operator+(const difference_type d, iterator it) { + it += d; + return it; + } + iterator& operator-=(difference_type d) { + it_ -= d; + return *this; + } + friend iterator operator-(iterator it, difference_type d) { + it -= d; + return it; + } + + // indexable + reference operator[](difference_type d) const { return *(*this + d); } + + // random access iterator + difference_type operator-(const iterator& x) const { return it_ - x.it_; } + + private: + template + friend class RepeatedPtrIterator; + + // The internal iterator. + void* const* it_; +}; + +// Provide an iterator that operates on pointers to the underlying objects +// rather than the objects themselves as RepeatedPtrIterator does. +// Consider using this when working with stl algorithms that change +// the array. +// The VoidPtr template parameter holds the type-agnostic pointer value +// referenced by the iterator. It should either be "void *" for a mutable +// iterator, or "const void* const" for a constant iterator. +template +class RepeatedPtrOverPtrsIterator { + public: + using iterator = RepeatedPtrOverPtrsIterator; + using iterator_category = std::random_access_iterator_tag; + using value_type = typename std::remove_const::type; + using difference_type = std::ptrdiff_t; + using pointer = Element*; + using reference = Element&; + + RepeatedPtrOverPtrsIterator() : it_(NULL) {} + explicit RepeatedPtrOverPtrsIterator(VoidPtr* it) : it_(it) {} + + // dereferenceable + reference operator*() const { return *reinterpret_cast(it_); } + pointer operator->() const { return &(operator*()); } + + // {inc,dec}rementable + iterator& operator++() { + ++it_; + return *this; + } + iterator operator++(int) { return iterator(it_++); } + iterator& operator--() { + --it_; + return *this; + } + iterator operator--(int) { return iterator(it_--); } + + // equality_comparable + bool operator==(const iterator& x) const { return it_ == x.it_; } + bool operator!=(const iterator& x) const { return it_ != x.it_; } + + // less_than_comparable + bool operator<(const iterator& x) const { return it_ < x.it_; } + bool operator<=(const iterator& x) const { return it_ <= x.it_; } + bool operator>(const iterator& x) const { return it_ > x.it_; } + bool operator>=(const iterator& x) const { return it_ >= x.it_; } + + // addable, subtractable + iterator& operator+=(difference_type d) { + it_ += d; + return *this; + } + friend iterator operator+(iterator it, difference_type d) { + it += d; + return it; + } + friend iterator operator+(difference_type d, iterator it) { + it += d; + return it; + } + iterator& operator-=(difference_type d) { + it_ -= d; + return *this; + } + friend iterator operator-(iterator it, difference_type d) { + it -= d; + return it; + } + + // indexable + reference operator[](difference_type d) const { return *(*this + d); } + + // random access iterator + difference_type operator-(const iterator& x) const { return it_ - x.it_; } + + private: + template + friend class RepeatedPtrIterator; + + // The internal iterator. + VoidPtr* it_; +}; + +void RepeatedPtrFieldBase::InternalSwap(RepeatedPtrFieldBase* other) { + GOOGLE_DCHECK(this != other); + GOOGLE_DCHECK(GetArena() == other->GetArena()); + + // Swap all fields at once. + static_assert(std::is_standard_layout::value, + "offsetof() requires standard layout before c++17"); + internal::memswaprep_) - + offsetof(RepeatedPtrFieldBase, current_size_)>( + reinterpret_cast(this) + + offsetof(RepeatedPtrFieldBase, current_size_), + reinterpret_cast(other) + + offsetof(RepeatedPtrFieldBase, current_size_)); +} + +} // namespace internal + +template +inline typename RepeatedPtrField::iterator +RepeatedPtrField::begin() { + return iterator(raw_data()); +} +template +inline typename RepeatedPtrField::const_iterator +RepeatedPtrField::begin() const { + return iterator(raw_data()); +} +template +inline typename RepeatedPtrField::const_iterator +RepeatedPtrField::cbegin() const { + return begin(); +} +template +inline typename RepeatedPtrField::iterator +RepeatedPtrField::end() { + return iterator(raw_data() + size()); +} +template +inline typename RepeatedPtrField::const_iterator +RepeatedPtrField::end() const { + return iterator(raw_data() + size()); +} +template +inline typename RepeatedPtrField::const_iterator +RepeatedPtrField::cend() const { + return end(); +} + +template +inline typename RepeatedPtrField::pointer_iterator +RepeatedPtrField::pointer_begin() { + return pointer_iterator(raw_mutable_data()); +} +template +inline typename RepeatedPtrField::const_pointer_iterator +RepeatedPtrField::pointer_begin() const { + return const_pointer_iterator(const_cast(raw_data())); +} +template +inline typename RepeatedPtrField::pointer_iterator +RepeatedPtrField::pointer_end() { + return pointer_iterator(raw_mutable_data() + size()); +} +template +inline typename RepeatedPtrField::const_pointer_iterator +RepeatedPtrField::pointer_end() const { + return const_pointer_iterator( + const_cast(raw_data() + size())); +} + +// Iterators and helper functions that follow the spirit of the STL +// std::back_insert_iterator and std::back_inserter but are tailor-made +// for RepeatedField and RepeatedPtrField. Typical usage would be: +// +// std::copy(some_sequence.begin(), some_sequence.end(), +// RepeatedFieldBackInserter(proto.mutable_sequence())); +// +// Ported by johannes from util/gtl/proto-array-iterators.h + +namespace internal { +// A back inserter for RepeatedField objects. +template +class RepeatedFieldBackInsertIterator + : public std::iterator { + public: + explicit RepeatedFieldBackInsertIterator( + RepeatedField* const mutable_field) + : field_(mutable_field) {} + RepeatedFieldBackInsertIterator& operator=(const T& value) { + field_->Add(value); + return *this; + } + RepeatedFieldBackInsertIterator& operator*() { return *this; } + RepeatedFieldBackInsertIterator& operator++() { return *this; } + RepeatedFieldBackInsertIterator& operator++(int /* unused */) { + return *this; + } + + private: + RepeatedField* field_; +}; + +// A back inserter for RepeatedPtrField objects. +template +class RepeatedPtrFieldBackInsertIterator + : public std::iterator { + public: + RepeatedPtrFieldBackInsertIterator(RepeatedPtrField* const mutable_field) + : field_(mutable_field) {} + RepeatedPtrFieldBackInsertIterator& operator=(const T& value) { + *field_->Add() = value; + return *this; + } + RepeatedPtrFieldBackInsertIterator& operator=( + const T* const ptr_to_value) { + *field_->Add() = *ptr_to_value; + return *this; + } + RepeatedPtrFieldBackInsertIterator& operator=(T&& value) { + *field_->Add() = std::move(value); + return *this; + } + RepeatedPtrFieldBackInsertIterator& operator*() { return *this; } + RepeatedPtrFieldBackInsertIterator& operator++() { return *this; } + RepeatedPtrFieldBackInsertIterator& operator++(int /* unused */) { + return *this; + } + + private: + RepeatedPtrField* field_; +}; + +// A back inserter for RepeatedPtrFields that inserts by transferring ownership +// of a pointer. +template +class AllocatedRepeatedPtrFieldBackInsertIterator + : public std::iterator { + public: + explicit AllocatedRepeatedPtrFieldBackInsertIterator( + RepeatedPtrField* const mutable_field) + : field_(mutable_field) {} + AllocatedRepeatedPtrFieldBackInsertIterator& operator=( + T* const ptr_to_value) { + field_->AddAllocated(ptr_to_value); + return *this; + } + AllocatedRepeatedPtrFieldBackInsertIterator& operator*() { return *this; } + AllocatedRepeatedPtrFieldBackInsertIterator& operator++() { return *this; } + AllocatedRepeatedPtrFieldBackInsertIterator& operator++(int /* unused */) { + return *this; + } + + private: + RepeatedPtrField* field_; +}; + +// Almost identical to AllocatedRepeatedPtrFieldBackInsertIterator. This one +// uses the UnsafeArenaAddAllocated instead. +template +class UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator + : public std::iterator { + public: + explicit UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator( + RepeatedPtrField* const mutable_field) + : field_(mutable_field) {} + UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator& operator=( + T const* const ptr_to_value) { + field_->UnsafeArenaAddAllocated(const_cast(ptr_to_value)); + return *this; + } + UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator& operator*() { + return *this; + } + UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator& operator++() { + return *this; + } + UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator& operator++( + int /* unused */) { + return *this; + } + + private: + RepeatedPtrField* field_; +}; + +} // namespace internal + +// Provides a back insert iterator for RepeatedField instances, +// similar to std::back_inserter(). +template +internal::RepeatedFieldBackInsertIterator RepeatedFieldBackInserter( + RepeatedField* const mutable_field) { + return internal::RepeatedFieldBackInsertIterator(mutable_field); +} + +// Provides a back insert iterator for RepeatedPtrField instances, +// similar to std::back_inserter(). +template +internal::RepeatedPtrFieldBackInsertIterator RepeatedPtrFieldBackInserter( + RepeatedPtrField* const mutable_field) { + return internal::RepeatedPtrFieldBackInsertIterator(mutable_field); +} + +// Special back insert iterator for RepeatedPtrField instances, just in +// case someone wants to write generic template code that can access both +// RepeatedFields and RepeatedPtrFields using a common name. +template +internal::RepeatedPtrFieldBackInsertIterator RepeatedFieldBackInserter( + RepeatedPtrField* const mutable_field) { + return internal::RepeatedPtrFieldBackInsertIterator(mutable_field); +} + +// Provides a back insert iterator for RepeatedPtrField instances +// similar to std::back_inserter() which transfers the ownership while +// copying elements. +template +internal::AllocatedRepeatedPtrFieldBackInsertIterator +AllocatedRepeatedPtrFieldBackInserter( + RepeatedPtrField* const mutable_field) { + return internal::AllocatedRepeatedPtrFieldBackInsertIterator( + mutable_field); +} + +// Similar to AllocatedRepeatedPtrFieldBackInserter, using +// UnsafeArenaAddAllocated instead of AddAllocated. +// This is slightly faster if that matters. It is also useful in legacy code +// that uses temporary ownership to avoid copies. Example: +// RepeatedPtrField temp_field; +// temp_field.AddAllocated(new T); +// ... // Do something with temp_field +// temp_field.ExtractSubrange(0, temp_field.size(), nullptr); +// If you put temp_field on the arena this fails, because the ownership +// transfers to the arena at the "AddAllocated" call and is not released anymore +// causing a double delete. Using UnsafeArenaAddAllocated prevents this. +template +internal::UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator +UnsafeArenaAllocatedRepeatedPtrFieldBackInserter( + RepeatedPtrField* const mutable_field) { + return internal::UnsafeArenaAllocatedRepeatedPtrFieldBackInsertIterator( + mutable_field); +} + +// Extern declarations of common instantiations to reduce library bloat. +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE RepeatedField; +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE RepeatedField; +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE RepeatedField; +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE RepeatedField; +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE RepeatedField; +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE RepeatedField; +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE RepeatedField; +extern template class PROTOBUF_EXPORT_TEMPLATE_DECLARE + RepeatedPtrField; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_REPEATED_FIELD_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/service.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/service.h new file mode 100644 index 0000000000000000000000000000000000000000..830792456ff72855e7cc08f4983cd8dc4f1c1d0b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/service.h @@ -0,0 +1,298 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// DEPRECATED: This module declares the abstract interfaces underlying proto2 +// RPC services. These are intended to be independent of any particular RPC +// implementation, so that proto2 services can be used on top of a variety +// of implementations. Starting with version 2.3.0, RPC implementations should +// not try to build on these, but should instead provide code generator plugins +// which generate code specific to the particular RPC implementation. This way +// the generated code can be more appropriate for the implementation in use +// and can avoid unnecessary layers of indirection. +// +// +// When you use the protocol compiler to compile a service definition, it +// generates two classes: An abstract interface for the service (with +// methods matching the service definition) and a "stub" implementation. +// A stub is just a type-safe wrapper around an RpcChannel which emulates a +// local implementation of the service. +// +// For example, the service definition: +// service MyService { +// rpc Foo(MyRequest) returns(MyResponse); +// } +// will generate abstract interface "MyService" and class "MyService::Stub". +// You could implement a MyService as follows: +// class MyServiceImpl : public MyService { +// public: +// MyServiceImpl() {} +// ~MyServiceImpl() {} +// +// // implements MyService --------------------------------------- +// +// void Foo(google::protobuf::RpcController* controller, +// const MyRequest* request, +// MyResponse* response, +// Closure* done) { +// // ... read request and fill in response ... +// done->Run(); +// } +// }; +// You would then register an instance of MyServiceImpl with your RPC server +// implementation. (How to do that depends on the implementation.) +// +// To call a remote MyServiceImpl, first you need an RpcChannel connected to it. +// How to construct a channel depends, again, on your RPC implementation. +// Here we use a hypothetical "MyRpcChannel" as an example: +// MyRpcChannel channel("rpc:hostname:1234/myservice"); +// MyRpcController controller; +// MyServiceImpl::Stub stub(&channel); +// FooRequest request; +// FooResponse response; +// +// // ... fill in request ... +// +// stub.Foo(&controller, request, &response, NewCallback(HandleResponse)); +// +// On Thread-Safety: +// +// Different RPC implementations may make different guarantees about what +// threads they may run callbacks on, and what threads the application is +// allowed to use to call the RPC system. Portable software should be ready +// for callbacks to be called on any thread, but should not try to call the +// RPC system from any thread except for the ones on which it received the +// callbacks. Realistically, though, simple software will probably want to +// use a single-threaded RPC system while high-end software will want to +// use multiple threads. RPC implementations should provide multiple +// choices. + +#ifndef GOOGLE_PROTOBUF_SERVICE_H__ +#define GOOGLE_PROTOBUF_SERVICE_H__ + +#include +#include +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +#include + +namespace google { +namespace protobuf { + +// Defined in this file. +class Service; +class RpcController; +class RpcChannel; + +// Defined in other files. +class Descriptor; // descriptor.h +class ServiceDescriptor; // descriptor.h +class MethodDescriptor; // descriptor.h +class Message; // message.h + +// Abstract base interface for protocol-buffer-based RPC services. Services +// themselves are abstract interfaces (implemented either by servers or as +// stubs), but they subclass this base interface. The methods of this +// interface can be used to call the methods of the Service without knowing +// its exact type at compile time (analogous to Reflection). +class PROTOBUF_EXPORT Service { + public: + inline Service() {} + virtual ~Service(); + + // When constructing a stub, you may pass STUB_OWNS_CHANNEL as the second + // parameter to the constructor to tell it to delete its RpcChannel when + // destroyed. + enum ChannelOwnership { STUB_OWNS_CHANNEL, STUB_DOESNT_OWN_CHANNEL }; + + // Get the ServiceDescriptor describing this service and its methods. + virtual const ServiceDescriptor* GetDescriptor() = 0; + + // Call a method of the service specified by MethodDescriptor. This is + // normally implemented as a simple switch() that calls the standard + // definitions of the service's methods. + // + // Preconditions: + // * method->service() == GetDescriptor() + // * request and response are of the exact same classes as the objects + // returned by GetRequestPrototype(method) and + // GetResponsePrototype(method). + // * After the call has started, the request must not be modified and the + // response must not be accessed at all until "done" is called. + // * "controller" is of the correct type for the RPC implementation being + // used by this Service. For stubs, the "correct type" depends on the + // RpcChannel which the stub is using. Server-side Service + // implementations are expected to accept whatever type of RpcController + // the server-side RPC implementation uses. + // + // Postconditions: + // * "done" will be called when the method is complete. This may be + // before CallMethod() returns or it may be at some point in the future. + // * If the RPC succeeded, "response" contains the response returned by + // the server. + // * If the RPC failed, "response"'s contents are undefined. The + // RpcController can be queried to determine if an error occurred and + // possibly to get more information about the error. + virtual void CallMethod(const MethodDescriptor* method, + RpcController* controller, const Message* request, + Message* response, Closure* done) = 0; + + // CallMethod() requires that the request and response passed in are of a + // particular subclass of Message. GetRequestPrototype() and + // GetResponsePrototype() get the default instances of these required types. + // You can then call Message::New() on these instances to construct mutable + // objects which you can then pass to CallMethod(). + // + // Example: + // const MethodDescriptor* method = + // service->GetDescriptor()->FindMethodByName("Foo"); + // Message* request = stub->GetRequestPrototype (method)->New(); + // Message* response = stub->GetResponsePrototype(method)->New(); + // request->ParseFromString(input); + // service->CallMethod(method, *request, response, callback); + virtual const Message& GetRequestPrototype( + const MethodDescriptor* method) const = 0; + virtual const Message& GetResponsePrototype( + const MethodDescriptor* method) const = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Service); +}; + +// An RpcController mediates a single method call. The primary purpose of +// the controller is to provide a way to manipulate settings specific to the +// RPC implementation and to find out about RPC-level errors. +// +// The methods provided by the RpcController interface are intended to be a +// "least common denominator" set of features which we expect all +// implementations to support. Specific implementations may provide more +// advanced features (e.g. deadline propagation). +class PROTOBUF_EXPORT RpcController { + public: + inline RpcController() {} + virtual ~RpcController(); + + // Client-side methods --------------------------------------------- + // These calls may be made from the client side only. Their results + // are undefined on the server side (may crash). + + // Resets the RpcController to its initial state so that it may be reused in + // a new call. Must not be called while an RPC is in progress. + virtual void Reset() = 0; + + // After a call has finished, returns true if the call failed. The possible + // reasons for failure depend on the RPC implementation. Failed() must not + // be called before a call has finished. If Failed() returns true, the + // contents of the response message are undefined. + virtual bool Failed() const = 0; + + // If Failed() is true, returns a human-readable description of the error. + virtual std::string ErrorText() const = 0; + + // Advises the RPC system that the caller desires that the RPC call be + // canceled. The RPC system may cancel it immediately, may wait awhile and + // then cancel it, or may not even cancel the call at all. If the call is + // canceled, the "done" callback will still be called and the RpcController + // will indicate that the call failed at that time. + virtual void StartCancel() = 0; + + // Server-side methods --------------------------------------------- + // These calls may be made from the server side only. Their results + // are undefined on the client side (may crash). + + // Causes Failed() to return true on the client side. "reason" will be + // incorporated into the message returned by ErrorText(). If you find + // you need to return machine-readable information about failures, you + // should incorporate it into your response protocol buffer and should + // NOT call SetFailed(). + virtual void SetFailed(const std::string& reason) = 0; + + // If true, indicates that the client canceled the RPC, so the server may + // as well give up on replying to it. The server should still call the + // final "done" callback. + virtual bool IsCanceled() const = 0; + + // Asks that the given callback be called when the RPC is canceled. The + // callback will always be called exactly once. If the RPC completes without + // being canceled, the callback will be called after completion. If the RPC + // has already been canceled when NotifyOnCancel() is called, the callback + // will be called immediately. + // + // NotifyOnCancel() must be called no more than once per request. + virtual void NotifyOnCancel(Closure* callback) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(RpcController); +}; + +// Abstract interface for an RPC channel. An RpcChannel represents a +// communication line to a Service which can be used to call that Service's +// methods. The Service may be running on another machine. Normally, you +// should not call an RpcChannel directly, but instead construct a stub Service +// wrapping it. Example: +// RpcChannel* channel = new MyRpcChannel("remotehost.example.com:1234"); +// MyService* service = new MyService::Stub(channel); +// service->MyMethod(request, &response, callback); +class PROTOBUF_EXPORT RpcChannel { + public: + inline RpcChannel() {} + virtual ~RpcChannel(); + + // Call the given method of the remote service. The signature of this + // procedure looks the same as Service::CallMethod(), but the requirements + // are less strict in one important way: the request and response objects + // need not be of any specific class as long as their descriptors are + // method->input_type() and method->output_type(). + virtual void CallMethod(const MethodDescriptor* method, + RpcController* controller, const Message* request, + Message* response, Closure* done) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(RpcChannel); +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_SERVICE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/source_context.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/source_context.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..f3d36923e0507a71255343c2c56d6853d774babe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/source_context.pb.h @@ -0,0 +1,300 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/source_context.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fsource_5fcontext_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fsource_5fcontext_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fsource_5fcontext_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2fsource_5fcontext_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fsource_5fcontext_2eproto; +PROTOBUF_NAMESPACE_OPEN +class SourceContext; +class SourceContextDefaultTypeInternal; +PROTOBUF_EXPORT extern SourceContextDefaultTypeInternal _SourceContext_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::SourceContext* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +// =================================================================== + +class PROTOBUF_EXPORT SourceContext PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.SourceContext) */ { + public: + inline SourceContext() : SourceContext(nullptr) {} + virtual ~SourceContext(); + + SourceContext(const SourceContext& from); + SourceContext(SourceContext&& from) noexcept + : SourceContext() { + *this = ::std::move(from); + } + + inline SourceContext& operator=(const SourceContext& from) { + CopyFrom(from); + return *this; + } + inline SourceContext& operator=(SourceContext&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SourceContext& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SourceContext* internal_default_instance() { + return reinterpret_cast( + &_SourceContext_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(SourceContext& a, SourceContext& b) { + a.Swap(&b); + } + inline void Swap(SourceContext* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SourceContext* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SourceContext* New() const final { + return CreateMaybeMessage(nullptr); + } + + SourceContext* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SourceContext& from); + void MergeFrom(const SourceContext& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SourceContext* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.SourceContext"; + } + protected: + explicit SourceContext(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fsource_5fcontext_2eproto); + return ::descriptor_table_google_2fprotobuf_2fsource_5fcontext_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kFileNameFieldNumber = 1, + }; + // string file_name = 1; + void clear_file_name(); + const std::string& file_name() const; + void set_file_name(const std::string& value); + void set_file_name(std::string&& value); + void set_file_name(const char* value); + void set_file_name(const char* value, size_t size); + std::string* mutable_file_name(); + std::string* release_file_name(); + void set_allocated_file_name(std::string* file_name); + private: + const std::string& _internal_file_name() const; + void _internal_set_file_name(const std::string& value); + std::string* _internal_mutable_file_name(); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.SourceContext) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr file_name_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fsource_5fcontext_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// SourceContext + +// string file_name = 1; +inline void SourceContext::clear_file_name() { + file_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SourceContext::file_name() const { + // @@protoc_insertion_point(field_get:google.protobuf.SourceContext.file_name) + return _internal_file_name(); +} +inline void SourceContext::set_file_name(const std::string& value) { + _internal_set_file_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.SourceContext.file_name) +} +inline std::string* SourceContext::mutable_file_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.SourceContext.file_name) + return _internal_mutable_file_name(); +} +inline const std::string& SourceContext::_internal_file_name() const { + return file_name_.Get(); +} +inline void SourceContext::_internal_set_file_name(const std::string& value) { + + file_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SourceContext::set_file_name(std::string&& value) { + + file_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.SourceContext.file_name) +} +inline void SourceContext::set_file_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + file_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.SourceContext.file_name) +} +inline void SourceContext::set_file_name(const char* value, + size_t size) { + + file_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.SourceContext.file_name) +} +inline std::string* SourceContext::_internal_mutable_file_name() { + + return file_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SourceContext::release_file_name() { + // @@protoc_insertion_point(field_release:google.protobuf.SourceContext.file_name) + return file_name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SourceContext::set_allocated_file_name(std::string* file_name) { + if (file_name != nullptr) { + + } else { + + } + file_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), file_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.SourceContext.file_name) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fsource_5fcontext_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/struct.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/struct.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..349b293ade6e3011ec94538b2acfb0099be7a369 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/struct.pb.h @@ -0,0 +1,1178 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/struct.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fstruct_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fstruct_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fstruct_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2fstruct_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[4] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fstruct_2eproto; +PROTOBUF_NAMESPACE_OPEN +class ListValue; +class ListValueDefaultTypeInternal; +PROTOBUF_EXPORT extern ListValueDefaultTypeInternal _ListValue_default_instance_; +class Struct; +class StructDefaultTypeInternal; +PROTOBUF_EXPORT extern StructDefaultTypeInternal _Struct_default_instance_; +class Struct_FieldsEntry_DoNotUse; +class Struct_FieldsEntry_DoNotUseDefaultTypeInternal; +PROTOBUF_EXPORT extern Struct_FieldsEntry_DoNotUseDefaultTypeInternal _Struct_FieldsEntry_DoNotUse_default_instance_; +class Value; +class ValueDefaultTypeInternal; +PROTOBUF_EXPORT extern ValueDefaultTypeInternal _Value_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::ListValue* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Struct* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Struct_FieldsEntry_DoNotUse* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Value* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +enum NullValue : int { + NULL_VALUE = 0, + NullValue_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + NullValue_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +PROTOBUF_EXPORT bool NullValue_IsValid(int value); +constexpr NullValue NullValue_MIN = NULL_VALUE; +constexpr NullValue NullValue_MAX = NULL_VALUE; +constexpr int NullValue_ARRAYSIZE = NullValue_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* NullValue_descriptor(); +template +inline const std::string& NullValue_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function NullValue_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + NullValue_descriptor(), enum_t_value); +} +inline bool NullValue_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, NullValue* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + NullValue_descriptor(), name, value); +} +// =================================================================== + +class Struct_FieldsEntry_DoNotUse : public ::PROTOBUF_NAMESPACE_ID::internal::MapEntry { +public: + typedef ::PROTOBUF_NAMESPACE_ID::internal::MapEntry SuperType; + Struct_FieldsEntry_DoNotUse(); + explicit Struct_FieldsEntry_DoNotUse(::PROTOBUF_NAMESPACE_ID::Arena* arena); + void MergeFrom(const Struct_FieldsEntry_DoNotUse& other); + static const Struct_FieldsEntry_DoNotUse* internal_default_instance() { return reinterpret_cast(&_Struct_FieldsEntry_DoNotUse_default_instance_); } + static bool ValidateKey(std::string* s) { + return ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String(s->data(), static_cast(s->size()), ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::PARSE, "google.protobuf.Struct.FieldsEntry.key"); + } + static bool ValidateValue(void*) { return true; } + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& other) final; + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fstruct_2eproto); + return ::descriptor_table_google_2fprotobuf_2fstruct_2eproto.file_level_metadata[0]; + } + + public: +}; + +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT Struct PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Struct) */ { + public: + inline Struct() : Struct(nullptr) {} + virtual ~Struct(); + + Struct(const Struct& from); + Struct(Struct&& from) noexcept + : Struct() { + *this = ::std::move(from); + } + + inline Struct& operator=(const Struct& from) { + CopyFrom(from); + return *this; + } + inline Struct& operator=(Struct&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Struct& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Struct* internal_default_instance() { + return reinterpret_cast( + &_Struct_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(Struct& a, Struct& b) { + a.Swap(&b); + } + inline void Swap(Struct* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Struct* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Struct* New() const final { + return CreateMaybeMessage(nullptr); + } + + Struct* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Struct& from); + void MergeFrom(const Struct& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Struct* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Struct"; + } + protected: + explicit Struct(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fstruct_2eproto); + return ::descriptor_table_google_2fprotobuf_2fstruct_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + + // accessors ------------------------------------------------------- + + enum : int { + kFieldsFieldNumber = 1, + }; + // map fields = 1; + int fields_size() const; + private: + int _internal_fields_size() const; + public: + void clear_fields(); + private: + const ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >& + _internal_fields() const; + ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >* + _internal_mutable_fields(); + public: + const ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >& + fields() const; + ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >* + mutable_fields(); + + // @@protoc_insertion_point(class_scope:google.protobuf.Struct) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::MapField< + Struct_FieldsEntry_DoNotUse, + std::string, PROTOBUF_NAMESPACE_ID::Value, + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING, + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_MESSAGE, + 0 > fields_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fstruct_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT Value PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Value) */ { + public: + inline Value() : Value(nullptr) {} + virtual ~Value(); + + Value(const Value& from); + Value(Value&& from) noexcept + : Value() { + *this = ::std::move(from); + } + + inline Value& operator=(const Value& from) { + CopyFrom(from); + return *this; + } + inline Value& operator=(Value&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Value& default_instance(); + + enum KindCase { + kNullValue = 1, + kNumberValue = 2, + kStringValue = 3, + kBoolValue = 4, + kStructValue = 5, + kListValue = 6, + KIND_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Value* internal_default_instance() { + return reinterpret_cast( + &_Value_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(Value& a, Value& b) { + a.Swap(&b); + } + inline void Swap(Value* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Value* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Value* New() const final { + return CreateMaybeMessage(nullptr); + } + + Value* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Value& from); + void MergeFrom(const Value& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Value* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Value"; + } + protected: + explicit Value(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fstruct_2eproto); + return ::descriptor_table_google_2fprotobuf_2fstruct_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNullValueFieldNumber = 1, + kNumberValueFieldNumber = 2, + kStringValueFieldNumber = 3, + kBoolValueFieldNumber = 4, + kStructValueFieldNumber = 5, + kListValueFieldNumber = 6, + }; + // .google.protobuf.NullValue null_value = 1; + private: + bool _internal_has_null_value() const; + public: + void clear_null_value(); + PROTOBUF_NAMESPACE_ID::NullValue null_value() const; + void set_null_value(PROTOBUF_NAMESPACE_ID::NullValue value); + private: + PROTOBUF_NAMESPACE_ID::NullValue _internal_null_value() const; + void _internal_set_null_value(PROTOBUF_NAMESPACE_ID::NullValue value); + public: + + // double number_value = 2; + private: + bool _internal_has_number_value() const; + public: + void clear_number_value(); + double number_value() const; + void set_number_value(double value); + private: + double _internal_number_value() const; + void _internal_set_number_value(double value); + public: + + // string string_value = 3; + private: + bool _internal_has_string_value() const; + public: + void clear_string_value(); + const std::string& string_value() const; + void set_string_value(const std::string& value); + void set_string_value(std::string&& value); + void set_string_value(const char* value); + void set_string_value(const char* value, size_t size); + std::string* mutable_string_value(); + std::string* release_string_value(); + void set_allocated_string_value(std::string* string_value); + private: + const std::string& _internal_string_value() const; + void _internal_set_string_value(const std::string& value); + std::string* _internal_mutable_string_value(); + public: + + // bool bool_value = 4; + private: + bool _internal_has_bool_value() const; + public: + void clear_bool_value(); + bool bool_value() const; + void set_bool_value(bool value); + private: + bool _internal_bool_value() const; + void _internal_set_bool_value(bool value); + public: + + // .google.protobuf.Struct struct_value = 5; + bool has_struct_value() const; + private: + bool _internal_has_struct_value() const; + public: + void clear_struct_value(); + const PROTOBUF_NAMESPACE_ID::Struct& struct_value() const; + PROTOBUF_NAMESPACE_ID::Struct* release_struct_value(); + PROTOBUF_NAMESPACE_ID::Struct* mutable_struct_value(); + void set_allocated_struct_value(PROTOBUF_NAMESPACE_ID::Struct* struct_value); + private: + const PROTOBUF_NAMESPACE_ID::Struct& _internal_struct_value() const; + PROTOBUF_NAMESPACE_ID::Struct* _internal_mutable_struct_value(); + public: + void unsafe_arena_set_allocated_struct_value( + PROTOBUF_NAMESPACE_ID::Struct* struct_value); + PROTOBUF_NAMESPACE_ID::Struct* unsafe_arena_release_struct_value(); + + // .google.protobuf.ListValue list_value = 6; + bool has_list_value() const; + private: + bool _internal_has_list_value() const; + public: + void clear_list_value(); + const PROTOBUF_NAMESPACE_ID::ListValue& list_value() const; + PROTOBUF_NAMESPACE_ID::ListValue* release_list_value(); + PROTOBUF_NAMESPACE_ID::ListValue* mutable_list_value(); + void set_allocated_list_value(PROTOBUF_NAMESPACE_ID::ListValue* list_value); + private: + const PROTOBUF_NAMESPACE_ID::ListValue& _internal_list_value() const; + PROTOBUF_NAMESPACE_ID::ListValue* _internal_mutable_list_value(); + public: + void unsafe_arena_set_allocated_list_value( + PROTOBUF_NAMESPACE_ID::ListValue* list_value); + PROTOBUF_NAMESPACE_ID::ListValue* unsafe_arena_release_list_value(); + + void clear_kind(); + KindCase kind_case() const; + // @@protoc_insertion_point(class_scope:google.protobuf.Value) + private: + class _Internal; + void set_has_null_value(); + void set_has_number_value(); + void set_has_string_value(); + void set_has_bool_value(); + void set_has_struct_value(); + void set_has_list_value(); + + inline bool has_kind() const; + inline void clear_has_kind(); + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + union KindUnion { + KindUnion() {} + int null_value_; + double number_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr string_value_; + bool bool_value_; + PROTOBUF_NAMESPACE_ID::Struct* struct_value_; + PROTOBUF_NAMESPACE_ID::ListValue* list_value_; + } kind_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_google_2fprotobuf_2fstruct_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT ListValue PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.ListValue) */ { + public: + inline ListValue() : ListValue(nullptr) {} + virtual ~ListValue(); + + ListValue(const ListValue& from); + ListValue(ListValue&& from) noexcept + : ListValue() { + *this = ::std::move(from); + } + + inline ListValue& operator=(const ListValue& from) { + CopyFrom(from); + return *this; + } + inline ListValue& operator=(ListValue&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ListValue& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ListValue* internal_default_instance() { + return reinterpret_cast( + &_ListValue_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(ListValue& a, ListValue& b) { + a.Swap(&b); + } + inline void Swap(ListValue* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ListValue* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ListValue* New() const final { + return CreateMaybeMessage(nullptr); + } + + ListValue* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ListValue& from); + void MergeFrom(const ListValue& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ListValue* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.ListValue"; + } + protected: + explicit ListValue(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fstruct_2eproto); + return ::descriptor_table_google_2fprotobuf_2fstruct_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kValuesFieldNumber = 1, + }; + // repeated .google.protobuf.Value values = 1; + int values_size() const; + private: + int _internal_values_size() const; + public: + void clear_values(); + PROTOBUF_NAMESPACE_ID::Value* mutable_values(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Value >* + mutable_values(); + private: + const PROTOBUF_NAMESPACE_ID::Value& _internal_values(int index) const; + PROTOBUF_NAMESPACE_ID::Value* _internal_add_values(); + public: + const PROTOBUF_NAMESPACE_ID::Value& values(int index) const; + PROTOBUF_NAMESPACE_ID::Value* add_values(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Value >& + values() const; + + // @@protoc_insertion_point(class_scope:google.protobuf.ListValue) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Value > values_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2fstruct_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// Struct + +// map fields = 1; +inline int Struct::_internal_fields_size() const { + return fields_.size(); +} +inline int Struct::fields_size() const { + return _internal_fields_size(); +} +inline void Struct::clear_fields() { + fields_.Clear(); +} +inline const ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >& +Struct::_internal_fields() const { + return fields_.GetMap(); +} +inline const ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >& +Struct::fields() const { + // @@protoc_insertion_point(field_map:google.protobuf.Struct.fields) + return _internal_fields(); +} +inline ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >* +Struct::_internal_mutable_fields() { + return fields_.MutableMap(); +} +inline ::PROTOBUF_NAMESPACE_ID::Map< std::string, PROTOBUF_NAMESPACE_ID::Value >* +Struct::mutable_fields() { + // @@protoc_insertion_point(field_mutable_map:google.protobuf.Struct.fields) + return _internal_mutable_fields(); +} + +// ------------------------------------------------------------------- + +// Value + +// .google.protobuf.NullValue null_value = 1; +inline bool Value::_internal_has_null_value() const { + return kind_case() == kNullValue; +} +inline void Value::set_has_null_value() { + _oneof_case_[0] = kNullValue; +} +inline void Value::clear_null_value() { + if (_internal_has_null_value()) { + kind_.null_value_ = 0; + clear_has_kind(); + } +} +inline PROTOBUF_NAMESPACE_ID::NullValue Value::_internal_null_value() const { + if (_internal_has_null_value()) { + return static_cast< PROTOBUF_NAMESPACE_ID::NullValue >(kind_.null_value_); + } + return static_cast< PROTOBUF_NAMESPACE_ID::NullValue >(0); +} +inline PROTOBUF_NAMESPACE_ID::NullValue Value::null_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.Value.null_value) + return _internal_null_value(); +} +inline void Value::_internal_set_null_value(PROTOBUF_NAMESPACE_ID::NullValue value) { + if (!_internal_has_null_value()) { + clear_kind(); + set_has_null_value(); + } + kind_.null_value_ = value; +} +inline void Value::set_null_value(PROTOBUF_NAMESPACE_ID::NullValue value) { + // @@protoc_insertion_point(field_set:google.protobuf.Value.null_value) + _internal_set_null_value(value); +} + +// double number_value = 2; +inline bool Value::_internal_has_number_value() const { + return kind_case() == kNumberValue; +} +inline void Value::set_has_number_value() { + _oneof_case_[0] = kNumberValue; +} +inline void Value::clear_number_value() { + if (_internal_has_number_value()) { + kind_.number_value_ = 0; + clear_has_kind(); + } +} +inline double Value::_internal_number_value() const { + if (_internal_has_number_value()) { + return kind_.number_value_; + } + return 0; +} +inline void Value::_internal_set_number_value(double value) { + if (!_internal_has_number_value()) { + clear_kind(); + set_has_number_value(); + } + kind_.number_value_ = value; +} +inline double Value::number_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.Value.number_value) + return _internal_number_value(); +} +inline void Value::set_number_value(double value) { + _internal_set_number_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.Value.number_value) +} + +// string string_value = 3; +inline bool Value::_internal_has_string_value() const { + return kind_case() == kStringValue; +} +inline void Value::set_has_string_value() { + _oneof_case_[0] = kStringValue; +} +inline void Value::clear_string_value() { + if (_internal_has_string_value()) { + kind_.string_value_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + clear_has_kind(); + } +} +inline const std::string& Value::string_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.Value.string_value) + return _internal_string_value(); +} +inline void Value::set_string_value(const std::string& value) { + _internal_set_string_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.Value.string_value) +} +inline std::string* Value::mutable_string_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Value.string_value) + return _internal_mutable_string_value(); +} +inline const std::string& Value::_internal_string_value() const { + if (_internal_has_string_value()) { + return kind_.string_value_.Get(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void Value::_internal_set_string_value(const std::string& value) { + if (!_internal_has_string_value()) { + clear_kind(); + set_has_string_value(); + kind_.string_value_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + kind_.string_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Value::set_string_value(std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.Value.string_value) + if (!_internal_has_string_value()) { + clear_kind(); + set_has_string_value(); + kind_.string_value_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + kind_.string_value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.Value.string_value) +} +inline void Value::set_string_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_string_value()) { + clear_kind(); + set_has_string_value(); + kind_.string_value_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + kind_.string_value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value), GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.Value.string_value) +} +inline void Value::set_string_value(const char* value, + size_t size) { + if (!_internal_has_string_value()) { + clear_kind(); + set_has_string_value(); + kind_.string_value_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + kind_.string_value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), + GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.Value.string_value) +} +inline std::string* Value::_internal_mutable_string_value() { + if (!_internal_has_string_value()) { + clear_kind(); + set_has_string_value(); + kind_.string_value_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return kind_.string_value_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Value::release_string_value() { + // @@protoc_insertion_point(field_release:google.protobuf.Value.string_value) + if (_internal_has_string_value()) { + clear_has_kind(); + return kind_.string_value_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void Value::set_allocated_string_value(std::string* string_value) { + if (has_kind()) { + clear_kind(); + } + if (string_value != nullptr) { + set_has_string_value(); + kind_.string_value_.UnsafeSetDefault(string_value); + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); + if (arena != nullptr) { + arena->Own(string_value); + } + } + // @@protoc_insertion_point(field_set_allocated:google.protobuf.Value.string_value) +} + +// bool bool_value = 4; +inline bool Value::_internal_has_bool_value() const { + return kind_case() == kBoolValue; +} +inline void Value::set_has_bool_value() { + _oneof_case_[0] = kBoolValue; +} +inline void Value::clear_bool_value() { + if (_internal_has_bool_value()) { + kind_.bool_value_ = false; + clear_has_kind(); + } +} +inline bool Value::_internal_bool_value() const { + if (_internal_has_bool_value()) { + return kind_.bool_value_; + } + return false; +} +inline void Value::_internal_set_bool_value(bool value) { + if (!_internal_has_bool_value()) { + clear_kind(); + set_has_bool_value(); + } + kind_.bool_value_ = value; +} +inline bool Value::bool_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.Value.bool_value) + return _internal_bool_value(); +} +inline void Value::set_bool_value(bool value) { + _internal_set_bool_value(value); + // @@protoc_insertion_point(field_set:google.protobuf.Value.bool_value) +} + +// .google.protobuf.Struct struct_value = 5; +inline bool Value::_internal_has_struct_value() const { + return kind_case() == kStructValue; +} +inline bool Value::has_struct_value() const { + return _internal_has_struct_value(); +} +inline void Value::set_has_struct_value() { + _oneof_case_[0] = kStructValue; +} +inline void Value::clear_struct_value() { + if (_internal_has_struct_value()) { + if (GetArena() == nullptr) { + delete kind_.struct_value_; + } + clear_has_kind(); + } +} +inline PROTOBUF_NAMESPACE_ID::Struct* Value::release_struct_value() { + // @@protoc_insertion_point(field_release:google.protobuf.Value.struct_value) + if (_internal_has_struct_value()) { + clear_has_kind(); + PROTOBUF_NAMESPACE_ID::Struct* temp = kind_.struct_value_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + kind_.struct_value_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const PROTOBUF_NAMESPACE_ID::Struct& Value::_internal_struct_value() const { + return _internal_has_struct_value() + ? *kind_.struct_value_ + : *reinterpret_cast< PROTOBUF_NAMESPACE_ID::Struct*>(&PROTOBUF_NAMESPACE_ID::_Struct_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::Struct& Value::struct_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.Value.struct_value) + return _internal_struct_value(); +} +inline PROTOBUF_NAMESPACE_ID::Struct* Value::unsafe_arena_release_struct_value() { + // @@protoc_insertion_point(field_unsafe_arena_release:google.protobuf.Value.struct_value) + if (_internal_has_struct_value()) { + clear_has_kind(); + PROTOBUF_NAMESPACE_ID::Struct* temp = kind_.struct_value_; + kind_.struct_value_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Value::unsafe_arena_set_allocated_struct_value(PROTOBUF_NAMESPACE_ID::Struct* struct_value) { + clear_kind(); + if (struct_value) { + set_has_struct_value(); + kind_.struct_value_ = struct_value; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.Value.struct_value) +} +inline PROTOBUF_NAMESPACE_ID::Struct* Value::_internal_mutable_struct_value() { + if (!_internal_has_struct_value()) { + clear_kind(); + set_has_struct_value(); + kind_.struct_value_ = CreateMaybeMessage< PROTOBUF_NAMESPACE_ID::Struct >(GetArena()); + } + return kind_.struct_value_; +} +inline PROTOBUF_NAMESPACE_ID::Struct* Value::mutable_struct_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Value.struct_value) + return _internal_mutable_struct_value(); +} + +// .google.protobuf.ListValue list_value = 6; +inline bool Value::_internal_has_list_value() const { + return kind_case() == kListValue; +} +inline bool Value::has_list_value() const { + return _internal_has_list_value(); +} +inline void Value::set_has_list_value() { + _oneof_case_[0] = kListValue; +} +inline void Value::clear_list_value() { + if (_internal_has_list_value()) { + if (GetArena() == nullptr) { + delete kind_.list_value_; + } + clear_has_kind(); + } +} +inline PROTOBUF_NAMESPACE_ID::ListValue* Value::release_list_value() { + // @@protoc_insertion_point(field_release:google.protobuf.Value.list_value) + if (_internal_has_list_value()) { + clear_has_kind(); + PROTOBUF_NAMESPACE_ID::ListValue* temp = kind_.list_value_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + kind_.list_value_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const PROTOBUF_NAMESPACE_ID::ListValue& Value::_internal_list_value() const { + return _internal_has_list_value() + ? *kind_.list_value_ + : *reinterpret_cast< PROTOBUF_NAMESPACE_ID::ListValue*>(&PROTOBUF_NAMESPACE_ID::_ListValue_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::ListValue& Value::list_value() const { + // @@protoc_insertion_point(field_get:google.protobuf.Value.list_value) + return _internal_list_value(); +} +inline PROTOBUF_NAMESPACE_ID::ListValue* Value::unsafe_arena_release_list_value() { + // @@protoc_insertion_point(field_unsafe_arena_release:google.protobuf.Value.list_value) + if (_internal_has_list_value()) { + clear_has_kind(); + PROTOBUF_NAMESPACE_ID::ListValue* temp = kind_.list_value_; + kind_.list_value_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Value::unsafe_arena_set_allocated_list_value(PROTOBUF_NAMESPACE_ID::ListValue* list_value) { + clear_kind(); + if (list_value) { + set_has_list_value(); + kind_.list_value_ = list_value; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.Value.list_value) +} +inline PROTOBUF_NAMESPACE_ID::ListValue* Value::_internal_mutable_list_value() { + if (!_internal_has_list_value()) { + clear_kind(); + set_has_list_value(); + kind_.list_value_ = CreateMaybeMessage< PROTOBUF_NAMESPACE_ID::ListValue >(GetArena()); + } + return kind_.list_value_; +} +inline PROTOBUF_NAMESPACE_ID::ListValue* Value::mutable_list_value() { + // @@protoc_insertion_point(field_mutable:google.protobuf.Value.list_value) + return _internal_mutable_list_value(); +} + +inline bool Value::has_kind() const { + return kind_case() != KIND_NOT_SET; +} +inline void Value::clear_has_kind() { + _oneof_case_[0] = KIND_NOT_SET; +} +inline Value::KindCase Value::kind_case() const { + return Value::KindCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// ListValue + +// repeated .google.protobuf.Value values = 1; +inline int ListValue::_internal_values_size() const { + return values_.size(); +} +inline int ListValue::values_size() const { + return _internal_values_size(); +} +inline void ListValue::clear_values() { + values_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::Value* ListValue::mutable_values(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.ListValue.values) + return values_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Value >* +ListValue::mutable_values() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.ListValue.values) + return &values_; +} +inline const PROTOBUF_NAMESPACE_ID::Value& ListValue::_internal_values(int index) const { + return values_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::Value& ListValue::values(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.ListValue.values) + return _internal_values(index); +} +inline PROTOBUF_NAMESPACE_ID::Value* ListValue::_internal_add_values() { + return values_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::Value* ListValue::add_values() { + // @@protoc_insertion_point(field_add:google.protobuf.ListValue.values) + return _internal_add_values(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Value >& +ListValue::values() const { + // @@protoc_insertion_point(field_list:google.protobuf.ListValue.values) + return values_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::NullValue> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::NullValue>() { + return PROTOBUF_NAMESPACE_ID::NullValue_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fstruct_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/bytestream.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/bytestream.h new file mode 100644 index 0000000000000000000000000000000000000000..65d62941aa61faeae9eb0ca6f69b28b0b1e3b151 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/bytestream.h @@ -0,0 +1,356 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file declares the ByteSink and ByteSource abstract interfaces. These +// interfaces represent objects that consume (ByteSink) or produce (ByteSource) +// a sequence of bytes. Using these abstract interfaces in your APIs can help +// make your code work with a variety of input and output types. +// +// This file also declares the following commonly used implementations of these +// interfaces. +// +// ByteSink: +// UncheckedArrayByteSink Writes to an array, without bounds checking +// CheckedArrayByteSink Writes to an array, with bounds checking +// GrowingArrayByteSink Allocates and writes to a growable buffer +// StringByteSink Writes to an STL string +// NullByteSink Consumes a never-ending stream of bytes +// +// ByteSource: +// ArrayByteSource Reads from an array or string/StringPiece +// LimitedByteSource Limits the number of bytes read from an + +#ifndef GOOGLE_PROTOBUF_STUBS_BYTESTREAM_H_ +#define GOOGLE_PROTOBUF_STUBS_BYTESTREAM_H_ + +#include +#include + +#include +#include + +#include + +class CordByteSink; + +namespace google { +namespace protobuf { +namespace strings { + +// An abstract interface for an object that consumes a sequence of bytes. This +// interface offers a way to append data as well as a Flush() function. +// +// Example: +// +// string my_data; +// ... +// ByteSink* sink = ... +// sink->Append(my_data.data(), my_data.size()); +// sink->Flush(); +// +class PROTOBUF_EXPORT ByteSink { + public: + ByteSink() {} + virtual ~ByteSink() {} + + // Appends the "n" bytes starting at "bytes". + virtual void Append(const char* bytes, size_t n) = 0; + + // Flushes internal buffers. The default implementation does nothing. ByteSink + // subclasses may use internal buffers that require calling Flush() at the end + // of the stream. + virtual void Flush(); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ByteSink); +}; + +// An abstract interface for an object that produces a fixed-size sequence of +// bytes. +// +// Example: +// +// ByteSource* source = ... +// while (source->Available() > 0) { +// StringPiece data = source->Peek(); +// ... do something with "data" ... +// source->Skip(data.length()); +// } +// +class PROTOBUF_EXPORT ByteSource { + public: + ByteSource() {} + virtual ~ByteSource() {} + + // Returns the number of bytes left to read from the source. Available() + // should decrease by N each time Skip(N) is called. Available() may not + // increase. Available() returning 0 indicates that the ByteSource is + // exhausted. + // + // Note: Size() may have been a more appropriate name as it's more + // indicative of the fixed-size nature of a ByteSource. + virtual size_t Available() const = 0; + + // Returns a StringPiece of the next contiguous region of the source. Does not + // reposition the source. The returned region is empty iff Available() == 0. + // + // The returned region is valid until the next call to Skip() or until this + // object is destroyed, whichever occurs first. + // + // The length of the returned StringPiece will be <= Available(). + virtual StringPiece Peek() = 0; + + // Skips the next n bytes. Invalidates any StringPiece returned by a previous + // call to Peek(). + // + // REQUIRES: Available() >= n + virtual void Skip(size_t n) = 0; + + // Writes the next n bytes in this ByteSource to the given ByteSink, and + // advances this ByteSource past the copied bytes. The default implementation + // of this method just copies the bytes normally, but subclasses might + // override CopyTo to optimize certain cases. + // + // REQUIRES: Available() >= n + virtual void CopyTo(ByteSink* sink, size_t n); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ByteSource); +}; + +// +// Some commonly used implementations of ByteSink +// + +// Implementation of ByteSink that writes to an unsized byte array. No +// bounds-checking is performed--it is the caller's responsibility to ensure +// that the destination array is large enough. +// +// Example: +// +// char buf[10]; +// UncheckedArrayByteSink sink(buf); +// sink.Append("hi", 2); // OK +// sink.Append(data, 100); // WOOPS! Overflows buf[10]. +// +class PROTOBUF_EXPORT UncheckedArrayByteSink : public ByteSink { + public: + explicit UncheckedArrayByteSink(char* dest) : dest_(dest) {} + virtual void Append(const char* data, size_t n) override; + + // Returns the current output pointer so that a caller can see how many bytes + // were produced. + // + // Note: this method is not part of the ByteSink interface. + char* CurrentDestination() const { return dest_; } + + private: + char* dest_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(UncheckedArrayByteSink); +}; + +// Implementation of ByteSink that writes to a sized byte array. This sink will +// not write more than "capacity" bytes to outbuf. Once "capacity" bytes are +// appended, subsequent bytes will be ignored and Overflowed() will return true. +// Overflowed() does not cause a runtime error (i.e., it does not CHECK fail). +// +// Example: +// +// char buf[10]; +// CheckedArrayByteSink sink(buf, 10); +// sink.Append("hi", 2); // OK +// sink.Append(data, 100); // Will only write 8 more bytes +// +class PROTOBUF_EXPORT CheckedArrayByteSink : public ByteSink { + public: + CheckedArrayByteSink(char* outbuf, size_t capacity); + virtual void Append(const char* bytes, size_t n) override; + + // Returns the number of bytes actually written to the sink. + size_t NumberOfBytesWritten() const { return size_; } + + // Returns true if any bytes were discarded, i.e., if there was an + // attempt to write more than 'capacity' bytes. + bool Overflowed() const { return overflowed_; } + + private: + char* outbuf_; + const size_t capacity_; + size_t size_; + bool overflowed_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CheckedArrayByteSink); +}; + +// Implementation of ByteSink that allocates an internal buffer (a char array) +// and expands it as needed to accommodate appended data (similar to a string), +// and allows the caller to take ownership of the internal buffer via the +// GetBuffer() method. The buffer returned from GetBuffer() must be deleted by +// the caller with delete[]. GetBuffer() also sets the internal buffer to be +// empty, and subsequent appends to the sink will create a new buffer. The +// destructor will free the internal buffer if GetBuffer() was not called. +// +// Example: +// +// GrowingArrayByteSink sink(10); +// sink.Append("hi", 2); +// sink.Append(data, n); +// const char* buf = sink.GetBuffer(); // Ownership transferred +// delete[] buf; +// +class PROTOBUF_EXPORT GrowingArrayByteSink : public strings::ByteSink { + public: + explicit GrowingArrayByteSink(size_t estimated_size); + virtual ~GrowingArrayByteSink(); + virtual void Append(const char* bytes, size_t n) override; + + // Returns the allocated buffer, and sets nbytes to its size. The caller takes + // ownership of the buffer and must delete it with delete[]. + char* GetBuffer(size_t* nbytes); + + private: + void Expand(size_t amount); + void ShrinkToFit(); + + size_t capacity_; + char* buf_; + size_t size_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GrowingArrayByteSink); +}; + +// Implementation of ByteSink that appends to the given string. +// Existing contents of "dest" are not modified; new data is appended. +// +// Example: +// +// string dest = "Hello "; +// StringByteSink sink(&dest); +// sink.Append("World", 5); +// assert(dest == "Hello World"); +// +class PROTOBUF_EXPORT StringByteSink : public ByteSink { + public: + explicit StringByteSink(string* dest) : dest_(dest) {} + virtual void Append(const char* data, size_t n) override; + + private: + string* dest_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(StringByteSink); +}; + +// Implementation of ByteSink that discards all data. +// +// Example: +// +// NullByteSink sink; +// sink.Append(data, data.size()); // All data ignored. +// +class PROTOBUF_EXPORT NullByteSink : public ByteSink { + public: + NullByteSink() {} + void Append(const char* /*data*/, size_t /*n*/) override {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(NullByteSink); +}; + +// +// Some commonly used implementations of ByteSource +// + +// Implementation of ByteSource that reads from a StringPiece. +// +// Example: +// +// string data = "Hello"; +// ArrayByteSource source(data); +// assert(source.Available() == 5); +// assert(source.Peek() == "Hello"); +// +class PROTOBUF_EXPORT ArrayByteSource : public ByteSource { + public: + explicit ArrayByteSource(StringPiece s) : input_(s) {} + + virtual size_t Available() const override; + virtual StringPiece Peek() override; + virtual void Skip(size_t n) override; + + private: + StringPiece input_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ArrayByteSource); +}; + +// Implementation of ByteSource that wraps another ByteSource, limiting the +// number of bytes returned. +// +// The caller maintains ownership of the underlying source, and may not use the +// underlying source while using the LimitByteSource object. The underlying +// source's pointer is advanced by n bytes every time this LimitByteSource +// object is advanced by n. +// +// Example: +// +// string data = "Hello World"; +// ArrayByteSource abs(data); +// assert(abs.Available() == data.size()); +// +// LimitByteSource limit(abs, 5); +// assert(limit.Available() == 5); +// assert(limit.Peek() == "Hello"); +// +class PROTOBUF_EXPORT LimitByteSource : public ByteSource { + public: + // Returns at most "limit" bytes from "source". + LimitByteSource(ByteSource* source, size_t limit); + + virtual size_t Available() const override; + virtual StringPiece Peek() override; + virtual void Skip(size_t n) override; + + // We override CopyTo so that we can forward to the underlying source, in + // case it has an efficient implementation of CopyTo. + virtual void CopyTo(ByteSink* sink, size_t n) override; + + private: + ByteSource* source_; + size_t limit_; +}; + +} // namespace strings +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_BYTESTREAM_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/callback.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/callback.h new file mode 100644 index 0000000000000000000000000000000000000000..731d46fc821ba3055073a5f66046892c7485903d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/callback.h @@ -0,0 +1,588 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef GOOGLE_PROTOBUF_STUBS_CALLBACK_H_ +#define GOOGLE_PROTOBUF_STUBS_CALLBACK_H_ + +#include + +#include + +#include + +// =================================================================== +// emulates google3/base/callback.h + +namespace google { +namespace protobuf { + +// Abstract interface for a callback. When calling an RPC, you must provide +// a Closure to call when the procedure completes. See the Service interface +// in service.h. +// +// To automatically construct a Closure which calls a particular function or +// method with a particular set of parameters, use the NewCallback() function. +// Example: +// void FooDone(const FooResponse* response) { +// ... +// } +// +// void CallFoo() { +// ... +// // When done, call FooDone() and pass it a pointer to the response. +// Closure* callback = NewCallback(&FooDone, response); +// // Make the call. +// service->Foo(controller, request, response, callback); +// } +// +// Example that calls a method: +// class Handler { +// public: +// ... +// +// void FooDone(const FooResponse* response) { +// ... +// } +// +// void CallFoo() { +// ... +// // When done, call FooDone() and pass it a pointer to the response. +// Closure* callback = NewCallback(this, &Handler::FooDone, response); +// // Make the call. +// service->Foo(controller, request, response, callback); +// } +// }; +// +// Currently NewCallback() supports binding zero, one, or two arguments. +// +// Callbacks created with NewCallback() automatically delete themselves when +// executed. They should be used when a callback is to be called exactly +// once (usually the case with RPC callbacks). If a callback may be called +// a different number of times (including zero), create it with +// NewPermanentCallback() instead. You are then responsible for deleting the +// callback (using the "delete" keyword as normal). +// +// Note that NewCallback() is a bit touchy regarding argument types. Generally, +// the values you provide for the parameter bindings must exactly match the +// types accepted by the callback function. For example: +// void Foo(string s); +// NewCallback(&Foo, "foo"); // WON'T WORK: const char* != string +// NewCallback(&Foo, string("foo")); // WORKS +// Also note that the arguments cannot be references: +// void Foo(const string& s); +// string my_str; +// NewCallback(&Foo, my_str); // WON'T WORK: Can't use references. +// However, correctly-typed pointers will work just fine. +class PROTOBUF_EXPORT Closure { + public: + Closure() {} + virtual ~Closure(); + + virtual void Run() = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Closure); +}; + +template +class ResultCallback { + public: + ResultCallback() {} + virtual ~ResultCallback() {} + + virtual R Run() = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ResultCallback); +}; + +template +class PROTOBUF_EXPORT ResultCallback1 { + public: + ResultCallback1() {} + virtual ~ResultCallback1() {} + + virtual R Run(A1) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ResultCallback1); +}; + +template +class PROTOBUF_EXPORT ResultCallback2 { + public: + ResultCallback2() {} + virtual ~ResultCallback2() {} + + virtual R Run(A1,A2) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ResultCallback2); +}; + +namespace internal { + +class PROTOBUF_EXPORT FunctionClosure0 : public Closure { + public: + typedef void (*FunctionType)(); + + FunctionClosure0(FunctionType function, bool self_deleting) + : function_(function), self_deleting_(self_deleting) {} + ~FunctionClosure0(); + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + function_(); + if (needs_delete) delete this; + } + + private: + FunctionType function_; + bool self_deleting_; +}; + +template +class MethodClosure0 : public Closure { + public: + typedef void (Class::*MethodType)(); + + MethodClosure0(Class* object, MethodType method, bool self_deleting) + : object_(object), method_(method), self_deleting_(self_deleting) {} + ~MethodClosure0() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + (object_->*method_)(); + if (needs_delete) delete this; + } + + private: + Class* object_; + MethodType method_; + bool self_deleting_; +}; + +template +class FunctionClosure1 : public Closure { + public: + typedef void (*FunctionType)(Arg1 arg1); + + FunctionClosure1(FunctionType function, bool self_deleting, + Arg1 arg1) + : function_(function), self_deleting_(self_deleting), + arg1_(arg1) {} + ~FunctionClosure1() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + function_(arg1_); + if (needs_delete) delete this; + } + + private: + FunctionType function_; + bool self_deleting_; + Arg1 arg1_; +}; + +template +class MethodClosure1 : public Closure { + public: + typedef void (Class::*MethodType)(Arg1 arg1); + + MethodClosure1(Class* object, MethodType method, bool self_deleting, + Arg1 arg1) + : object_(object), method_(method), self_deleting_(self_deleting), + arg1_(arg1) {} + ~MethodClosure1() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + (object_->*method_)(arg1_); + if (needs_delete) delete this; + } + + private: + Class* object_; + MethodType method_; + bool self_deleting_; + Arg1 arg1_; +}; + +template +class FunctionClosure2 : public Closure { + public: + typedef void (*FunctionType)(Arg1 arg1, Arg2 arg2); + + FunctionClosure2(FunctionType function, bool self_deleting, + Arg1 arg1, Arg2 arg2) + : function_(function), self_deleting_(self_deleting), + arg1_(arg1), arg2_(arg2) {} + ~FunctionClosure2() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + function_(arg1_, arg2_); + if (needs_delete) delete this; + } + + private: + FunctionType function_; + bool self_deleting_; + Arg1 arg1_; + Arg2 arg2_; +}; + +template +class MethodClosure2 : public Closure { + public: + typedef void (Class::*MethodType)(Arg1 arg1, Arg2 arg2); + + MethodClosure2(Class* object, MethodType method, bool self_deleting, + Arg1 arg1, Arg2 arg2) + : object_(object), method_(method), self_deleting_(self_deleting), + arg1_(arg1), arg2_(arg2) {} + ~MethodClosure2() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + (object_->*method_)(arg1_, arg2_); + if (needs_delete) delete this; + } + + private: + Class* object_; + MethodType method_; + bool self_deleting_; + Arg1 arg1_; + Arg2 arg2_; +}; + +template +class FunctionResultCallback_0_0 : public ResultCallback { + public: + typedef R (*FunctionType)(); + + FunctionResultCallback_0_0(FunctionType function, bool self_deleting) + : function_(function), self_deleting_(self_deleting) {} + ~FunctionResultCallback_0_0() {} + + R Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; +}; + +template +class FunctionResultCallback_1_0 : public ResultCallback { + public: + typedef R (*FunctionType)(P1); + + FunctionResultCallback_1_0(FunctionType function, bool self_deleting, + P1 p1) + : function_(function), self_deleting_(self_deleting), p1_(p1) {} + ~FunctionResultCallback_1_0() {} + + R Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(p1_); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; + P1 p1_; +}; + +template +class FunctionResultCallback_0_1 : public ResultCallback1 { + public: + typedef R (*FunctionType)(Arg1 arg1); + + FunctionResultCallback_0_1(FunctionType function, bool self_deleting) + : function_(function), self_deleting_(self_deleting) {} + ~FunctionResultCallback_0_1() {} + + R Run(Arg1 a1) override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(a1); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; +}; + +template +class FunctionResultCallback_1_1 : public ResultCallback1 { + public: + typedef R (*FunctionType)(P1, A1); + + FunctionResultCallback_1_1(FunctionType function, bool self_deleting, + P1 p1) + : function_(function), self_deleting_(self_deleting), p1_(p1) {} + ~FunctionResultCallback_1_1() {} + + R Run(A1 a1) override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(p1_, a1); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; + P1 p1_; +}; + +template +struct InternalConstRef { + typedef typename std::remove_reference::type base_type; + typedef const base_type& type; +}; + +template +class MethodResultCallback_0_0 : public ResultCallback { + public: + typedef R (T::*MethodType)(); + MethodResultCallback_0_0(T* object, MethodType method, bool self_deleting) + : object_(object), + method_(method), + self_deleting_(self_deleting) {} + ~MethodResultCallback_0_0() {} + + R Run() { + bool needs_delete = self_deleting_; + R result = (object_->*method_)(); + if (needs_delete) delete this; + return result; + } + + private: + T* object_; + MethodType method_; + bool self_deleting_; +}; + +template +class MethodResultCallback_6_2 : public ResultCallback2 { + public: + typedef R (T::*MethodType)(P1, P2, P3, P4, P5, P6, A1, A2); + MethodResultCallback_6_2(T* object, MethodType method, bool self_deleting, + P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6) + : object_(object), + method_(method), + self_deleting_(self_deleting), + p1_(p1), + p2_(p2), + p3_(p3), + p4_(p4), + p5_(p5), + p6_(p6) {} + ~MethodResultCallback_6_2() {} + + R Run(A1 a1, A2 a2) override { + bool needs_delete = self_deleting_; + R result = (object_->*method_)(p1_, p2_, p3_, p4_, p5_, p6_, a1, a2); + if (needs_delete) delete this; + return result; + } + + private: + T* object_; + MethodType method_; + bool self_deleting_; + typename std::remove_reference::type p1_; + typename std::remove_reference::type p2_; + typename std::remove_reference::type p3_; + typename std::remove_reference::type p4_; + typename std::remove_reference::type p5_; + typename std::remove_reference::type p6_; +}; + +} // namespace internal + +// See Closure. +inline Closure* NewCallback(void (*function)()) { + return new internal::FunctionClosure0(function, true); +} + +// See Closure. +inline Closure* NewPermanentCallback(void (*function)()) { + return new internal::FunctionClosure0(function, false); +} + +// See Closure. +template +inline Closure* NewCallback(Class* object, void (Class::*method)()) { + return new internal::MethodClosure0(object, method, true); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(Class* object, void (Class::*method)()) { + return new internal::MethodClosure0(object, method, false); +} + +// See Closure. +template +inline Closure* NewCallback(void (*function)(Arg1), + Arg1 arg1) { + return new internal::FunctionClosure1(function, true, arg1); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(void (*function)(Arg1), + Arg1 arg1) { + return new internal::FunctionClosure1(function, false, arg1); +} + +// See Closure. +template +inline Closure* NewCallback(Class* object, void (Class::*method)(Arg1), + Arg1 arg1) { + return new internal::MethodClosure1(object, method, true, arg1); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(Class* object, void (Class::*method)(Arg1), + Arg1 arg1) { + return new internal::MethodClosure1(object, method, false, arg1); +} + +// See Closure. +template +inline Closure* NewCallback(void (*function)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::FunctionClosure2( + function, true, arg1, arg2); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(void (*function)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::FunctionClosure2( + function, false, arg1, arg2); +} + +// See Closure. +template +inline Closure* NewCallback(Class* object, void (Class::*method)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::MethodClosure2( + object, method, true, arg1, arg2); +} + +// See Closure. +template +inline Closure* NewPermanentCallback( + Class* object, void (Class::*method)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::MethodClosure2( + object, method, false, arg1, arg2); +} + +// See ResultCallback +template +inline ResultCallback* NewCallback(R (*function)()) { + return new internal::FunctionResultCallback_0_0(function, true); +} + +// See ResultCallback +template +inline ResultCallback* NewPermanentCallback(R (*function)()) { + return new internal::FunctionResultCallback_0_0(function, false); +} + +// See ResultCallback +template +inline ResultCallback* NewCallback(R (*function)(P1), P1 p1) { + return new internal::FunctionResultCallback_1_0( + function, true, p1); +} + +// See ResultCallback +template +inline ResultCallback* NewPermanentCallback( + R (*function)(P1), P1 p1) { + return new internal::FunctionResultCallback_1_0( + function, false, p1); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewCallback(R (*function)(A1)) { + return new internal::FunctionResultCallback_0_1(function, true); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewPermanentCallback(R (*function)(A1)) { + return new internal::FunctionResultCallback_0_1(function, false); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewCallback(R (*function)(P1, A1), P1 p1) { + return new internal::FunctionResultCallback_1_1( + function, true, p1); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewPermanentCallback( + R (*function)(P1, A1), P1 p1) { + return new internal::FunctionResultCallback_1_1( + function, false, p1); +} + +// See MethodResultCallback_0_0 +template +inline ResultCallback* NewPermanentCallback( + T1* object, R (T2::*function)()) { + return new internal::MethodResultCallback_0_0(object, function, false); +} + +// See MethodResultCallback_6_2 +template +inline ResultCallback2* NewPermanentCallback( + T* object, R (T::*function)(P1, P2, P3, P4, P5, P6, A1, A2), + typename internal::InternalConstRef::type p1, + typename internal::InternalConstRef::type p2, + typename internal::InternalConstRef::type p3, + typename internal::InternalConstRef::type p4, + typename internal::InternalConstRef::type p5, + typename internal::InternalConstRef::type p6) { + return new internal::MethodResultCallback_6_2(object, function, false, + p1, p2, p3, p4, p5, p6); +} + +// A function which does nothing. Useful for creating no-op callbacks, e.g.: +// Closure* nothing = NewCallback(&DoNothing); +void PROTOBUF_EXPORT DoNothing(); + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_CALLBACK_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/casts.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/casts.h new file mode 100644 index 0000000000000000000000000000000000000000..b77ca87d78970e2431d5cc65626b81d9275b1d41 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/casts.h @@ -0,0 +1,144 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2014 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_CASTS_H__ +#define GOOGLE_PROTOBUF_CASTS_H__ + +#include + +#include +#include + +namespace google { +namespace protobuf { +namespace internal { + +// Use implicit_cast as a safe version of static_cast or const_cast +// for upcasting in the type hierarchy (i.e. casting a pointer to Foo +// to a pointer to SuperclassOfFoo or casting a pointer to Foo to +// a const pointer to Foo). +// When you use implicit_cast, the compiler checks that the cast is safe. +// Such explicit implicit_casts are necessary in surprisingly many +// situations where C++ demands an exact type match instead of an +// argument type convertable to a target type. +// +// The From type can be inferred, so the preferred syntax for using +// implicit_cast is the same as for static_cast etc.: +// +// implicit_cast(expr) +// +// implicit_cast would have been part of the C++ standard library, +// but the proposal was submitted too late. It will probably make +// its way into the language in the future. +template +inline To implicit_cast(From const &f) { + return f; +} + +// When you upcast (that is, cast a pointer from type Foo to type +// SuperclassOfFoo), it's fine to use implicit_cast<>, since upcasts +// always succeed. When you downcast (that is, cast a pointer from +// type Foo to type SubclassOfFoo), static_cast<> isn't safe, because +// how do you know the pointer is really of type SubclassOfFoo? It +// could be a bare Foo, or of type DifferentSubclassOfFoo. Thus, +// when you downcast, you should use this macro. In debug mode, we +// use dynamic_cast<> to double-check the downcast is legal (we die +// if it's not). In normal mode, we do the efficient static_cast<> +// instead. Thus, it's important to test in debug mode to make sure +// the cast is legal! +// This is the only place in the code we should use dynamic_cast<>. +// In particular, you SHOULDN'T be using dynamic_cast<> in order to +// do RTTI (eg code like this: +// if (dynamic_cast(foo)) HandleASubclass1Object(foo); +// if (dynamic_cast(foo)) HandleASubclass2Object(foo); +// You should design the code some other way not to need this. + +template // use like this: down_cast(foo); +inline To down_cast(From* f) { // so we only accept pointers + // Ensures that To is a sub-type of From *. This test is here only + // for compile-time type checking, and has no overhead in an + // optimized build at run-time, as it will be optimized away + // completely. + if (false) { + implicit_cast(0); + } + +#if !defined(NDEBUG) && PROTOBUF_RTTI + assert(f == nullptr || dynamic_cast(f) != nullptr); // RTTI: debug mode only! +#endif + return static_cast(f); +} + +template // use like this: down_cast(foo); +inline To down_cast(From& f) { + typedef typename std::remove_reference::type* ToAsPointer; + // Ensures that To is a sub-type of From *. This test is here only + // for compile-time type checking, and has no overhead in an + // optimized build at run-time, as it will be optimized away + // completely. + if (false) { + implicit_cast(0); + } + +#if !defined(NDEBUG) && PROTOBUF_RTTI + // RTTI: debug mode only! + assert(dynamic_cast(&f) != nullptr); +#endif + return *static_cast(&f); +} + +template +inline To bit_cast(const From& from) { + GOOGLE_COMPILE_ASSERT(sizeof(From) == sizeof(To), + bit_cast_with_different_sizes); + To dest; + memcpy(&dest, &from, sizeof(dest)); + return dest; +} + +} // namespace internal + +// We made these internal so that they would show up as such in the docs, +// but we don't want to stick "internal::" in front of them everywhere. +using internal::implicit_cast; +using internal::down_cast; +using internal::bit_cast; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_CASTS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/common.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/common.h new file mode 100644 index 0000000000000000000000000000000000000000..ddfd338bf68f0b5602ff46e655a35ecd59071bd7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/common.h @@ -0,0 +1,207 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) and others +// +// Contains basic types and utilities used by the rest of the library. + +#ifndef GOOGLE_PROTOBUF_COMMON_H__ +#define GOOGLE_PROTOBUF_COMMON_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifndef PROTOBUF_USE_EXCEPTIONS +#if defined(_MSC_VER) && defined(_CPPUNWIND) + #define PROTOBUF_USE_EXCEPTIONS 1 +#elif defined(__EXCEPTIONS) + #define PROTOBUF_USE_EXCEPTIONS 1 +#else + #define PROTOBUF_USE_EXCEPTIONS 0 +#endif +#endif + +#if PROTOBUF_USE_EXCEPTIONS +#include +#endif +#if defined(__APPLE__) +#include // for TARGET_OS_IPHONE +#endif + +#if defined(__ANDROID__) || defined(GOOGLE_PROTOBUF_OS_ANDROID) || (defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE) || defined(GOOGLE_PROTOBUF_OS_IPHONE) +#include +#endif + +#include + +namespace std {} + +namespace google { +namespace protobuf { +namespace internal { + +// Some of these constants are macros rather than const ints so that they can +// be used in #if directives. + +// The current version, represented as a single integer to make comparison +// easier: major * 10^6 + minor * 10^3 + micro +#define GOOGLE_PROTOBUF_VERSION 3013000 + +// A suffix string for alpha, beta or rc releases. Empty for stable releases. +#define GOOGLE_PROTOBUF_VERSION_SUFFIX "" + +// The minimum header version which works with the current version of +// the library. This constant should only be used by protoc's C++ code +// generator. +static const int kMinHeaderVersionForLibrary = 3013000; + +// The minimum protoc version which works with the current version of the +// headers. +#define GOOGLE_PROTOBUF_MIN_PROTOC_VERSION 3013000 + +// The minimum header version which works with the current version of +// protoc. This constant should only be used in VerifyVersion(). +static const int kMinHeaderVersionForProtoc = 3013000; + +// Verifies that the headers and libraries are compatible. Use the macro +// below to call this. +void PROTOBUF_EXPORT VerifyVersion(int headerVersion, int minLibraryVersion, + const char* filename); + +// Converts a numeric version number to a string. +std::string PROTOBUF_EXPORT VersionString(int version); + +} // namespace internal + +// Place this macro in your main() function (or somewhere before you attempt +// to use the protobuf library) to verify that the version you link against +// matches the headers you compiled against. If a version mismatch is +// detected, the process will abort. +#define GOOGLE_PROTOBUF_VERIFY_VERSION \ + ::google::protobuf::internal::VerifyVersion( \ + GOOGLE_PROTOBUF_VERSION, GOOGLE_PROTOBUF_MIN_LIBRARY_VERSION, \ + __FILE__) + + +// =================================================================== +// from google3/util/utf8/public/unilib.h + +class StringPiece; +namespace internal { + +// Checks if the buffer contains structurally-valid UTF-8. Implemented in +// structurally_valid.cc. +PROTOBUF_EXPORT bool IsStructurallyValidUTF8(const char* buf, int len); + +inline bool IsStructurallyValidUTF8(StringPiece str) { + return IsStructurallyValidUTF8(str.data(), static_cast(str.length())); +} + +// Returns initial number of bytes of structurally valid UTF-8. +PROTOBUF_EXPORT int UTF8SpnStructurallyValid(StringPiece str); + +// Coerce UTF-8 byte string in src_str to be +// a structurally-valid equal-length string by selectively +// overwriting illegal bytes with replace_char (typically ' ' or '?'). +// replace_char must be legal printable 7-bit Ascii 0x20..0x7e. +// src_str is read-only. +// +// Returns pointer to output buffer, src_str.data() if no changes were made, +// or idst if some bytes were changed. idst is allocated by the caller +// and must be at least as big as src_str +// +// Optimized for: all structurally valid and no byte copying is done. +// +PROTOBUF_EXPORT char* UTF8CoerceToStructurallyValid(StringPiece str, char* dst, + char replace_char); + +} // namespace internal + +// This lives in message_lite.h now, but we leave this here for any users that +// #include common.h and not message_lite.h. +PROTOBUF_EXPORT void ShutdownProtobufLibrary(); + +namespace internal { + +// Strongly references the given variable such that the linker will be forced +// to pull in this variable's translation unit. +template +void StrongReference(const T& var) { + auto volatile unused = &var; + (void)&unused; // Use address to avoid an extra load of "unused". +} + +} // namespace internal + +#if PROTOBUF_USE_EXCEPTIONS +class FatalException : public std::exception { + public: + FatalException(const char* filename, int line, const std::string& message) + : filename_(filename), line_(line), message_(message) {} + virtual ~FatalException() throw(); + + virtual const char* what() const throw(); + + const char* filename() const { return filename_; } + int line() const { return line_; } + const std::string& message() const { return message_; } + + private: + const char* filename_; + const int line_; + const std::string message_; +}; +#endif + +// This is at the end of the file instead of the beginning to work around a bug +// in some versions of MSVC. +using std::string; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMMON_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/fastmem.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/fastmem.h new file mode 100644 index 0000000000000000000000000000000000000000..ba25746d319f09787a79db713b6e6255c02d1299 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/fastmem.h @@ -0,0 +1,162 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2014 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Fast memory copying and comparison routines. +// strings::fastmemcmp_inlined() replaces memcmp() +// strings::memcpy_inlined() replaces memcpy() +// strings::memeq(a, b, n) replaces memcmp(a, b, n) == 0 +// +// strings::*_inlined() routines are inline versions of the +// routines exported by this module. Sometimes using the inlined +// versions is faster. Measure before using the inlined versions. +// +// Performance measurement: +// strings::fastmemcmp_inlined +// Analysis: memcmp, fastmemcmp_inlined, fastmemcmp +// 2012-01-30 + +#ifndef GOOGLE_PROTOBUF_STUBS_FASTMEM_H_ +#define GOOGLE_PROTOBUF_STUBS_FASTMEM_H_ + +#include +#include +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace internal { + +// Return true if the n bytes at a equal the n bytes at b. +// The regions are allowed to overlap. +// +// The performance is similar to the performance memcmp(), but faster for +// moderately-sized inputs, or inputs that share a common prefix and differ +// somewhere in their last 8 bytes. Further optimizations can be added later +// if it makes sense to do so.:w +inline bool memeq(const char* a, const char* b, size_t n) { + size_t n_rounded_down = n & ~static_cast(7); + if (PROTOBUF_PREDICT_FALSE(n_rounded_down == 0)) { // n <= 7 + return memcmp(a, b, n) == 0; + } + // n >= 8 + uint64 u = GOOGLE_UNALIGNED_LOAD64(a) ^ GOOGLE_UNALIGNED_LOAD64(b); + uint64 v = GOOGLE_UNALIGNED_LOAD64(a + n - 8) ^ GOOGLE_UNALIGNED_LOAD64(b + n - 8); + if ((u | v) != 0) { // The first or last 8 bytes differ. + return false; + } + a += 8; + b += 8; + n = n_rounded_down - 8; + if (n > 128) { + // As of 2012, memcmp on x86-64 uses a big unrolled loop with SSE2 + // instructions, and while we could try to do something faster, it + // doesn't seem worth pursuing. + return memcmp(a, b, n) == 0; + } + for (; n >= 16; n -= 16) { + uint64 x = GOOGLE_UNALIGNED_LOAD64(a) ^ GOOGLE_UNALIGNED_LOAD64(b); + uint64 y = GOOGLE_UNALIGNED_LOAD64(a + 8) ^ GOOGLE_UNALIGNED_LOAD64(b + 8); + if ((x | y) != 0) { + return false; + } + a += 16; + b += 16; + } + // n must be 0 or 8 now because it was a multiple of 8 at the top of the loop. + return n == 0 || GOOGLE_UNALIGNED_LOAD64(a) == GOOGLE_UNALIGNED_LOAD64(b); +} + +inline int fastmemcmp_inlined(const char *a, const char *b, size_t n) { + if (n >= 64) { + return memcmp(a, b, n); + } + const char* a_limit = a + n; + while (a + sizeof(uint64) <= a_limit && + GOOGLE_UNALIGNED_LOAD64(a) == GOOGLE_UNALIGNED_LOAD64(b)) { + a += sizeof(uint64); + b += sizeof(uint64); + } + if (a + sizeof(uint32) <= a_limit && + GOOGLE_UNALIGNED_LOAD32(a) == GOOGLE_UNALIGNED_LOAD32(b)) { + a += sizeof(uint32); + b += sizeof(uint32); + } + while (a < a_limit) { + int d = + static_cast(static_cast(*a++) - static_cast(*b++)); + if (d) return d; + } + return 0; +} + +// The standard memcpy operation is slow for variable small sizes. +// This implementation inlines the optimal realization for sizes 1 to 16. +// To avoid code bloat don't use it in case of not performance-critical spots, +// nor when you don't expect very frequent values of size <= 16. +inline void memcpy_inlined(char *dst, const char *src, size_t size) { + // Compiler inlines code with minimal amount of data movement when third + // parameter of memcpy is a constant. + switch (size) { + case 1: memcpy(dst, src, 1); break; + case 2: memcpy(dst, src, 2); break; + case 3: memcpy(dst, src, 3); break; + case 4: memcpy(dst, src, 4); break; + case 5: memcpy(dst, src, 5); break; + case 6: memcpy(dst, src, 6); break; + case 7: memcpy(dst, src, 7); break; + case 8: memcpy(dst, src, 8); break; + case 9: memcpy(dst, src, 9); break; + case 10: memcpy(dst, src, 10); break; + case 11: memcpy(dst, src, 11); break; + case 12: memcpy(dst, src, 12); break; + case 13: memcpy(dst, src, 13); break; + case 14: memcpy(dst, src, 14); break; + case 15: memcpy(dst, src, 15); break; + case 16: memcpy(dst, src, 16); break; + default: memcpy(dst, src, size); break; + } +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_FASTMEM_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/hash.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/hash.h new file mode 100644 index 0000000000000000000000000000000000000000..4d61f3d44fb19dc1a8415c99dc8c96ddba846862 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/hash.h @@ -0,0 +1,127 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) + +#ifndef GOOGLE_PROTOBUF_STUBS_HASH_H__ +#define GOOGLE_PROTOBUF_STUBS_HASH_H__ + +#include +#include +#include +#include + +# define GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_START \ + namespace google { \ + namespace protobuf { +# define GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_END }} + +namespace google { +namespace protobuf { + +template +struct hash : public std::hash {}; + +template +struct hash { + inline size_t operator()(const Key* key) const { + return reinterpret_cast(key); + } +}; + +// Unlike the old SGI version, the TR1 "hash" does not special-case char*. So, +// we go ahead and provide our own implementation. +template <> +struct hash { + inline size_t operator()(const char* str) const { + size_t result = 0; + for (; *str != '\0'; str++) { + result = 5 * result + static_cast(*str); + } + return result; + } +}; + +template<> +struct hash { + size_t operator()(bool x) const { + return static_cast(x); + } +}; + +template <> +struct hash { + inline size_t operator()(const std::string& key) const { + return hash()(key.c_str()); + } + + static const size_t bucket_size = 4; + static const size_t min_buckets = 8; + inline bool operator()(const std::string& a, const std::string& b) const { + return a < b; + } +}; + +template +struct hash > { + inline size_t operator()(const std::pair& key) const { + size_t first_hash = hash()(key.first); + size_t second_hash = hash()(key.second); + + // FIXME(kenton): What is the best way to compute this hash? I have + // no idea! This seems a bit better than an XOR. + return first_hash * ((1 << 16) - 1) + second_hash; + } + + static const size_t bucket_size = 4; + static const size_t min_buckets = 8; + inline bool operator()(const std::pair& a, + const std::pair& b) const { + return a < b; + } +}; + +// Used by GCC/SGI STL only. (Why isn't this provided by the standard +// library? :( ) +struct streq { + inline bool operator()(const char* a, const char* b) const { + return strcmp(a, b) == 0; + } +}; + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_STUBS_HASH_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/logging.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/logging.h new file mode 100644 index 0000000000000000000000000000000000000000..318d1a435d94b1f64ac7921bbc36ed01b676bc9d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/logging.h @@ -0,0 +1,246 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_LOGGING_H_ +#define GOOGLE_PROTOBUF_STUBS_LOGGING_H_ + +#include +#include + +#include + +// =================================================================== +// emulates google3/base/logging.h + +namespace google { +namespace protobuf { + +enum LogLevel { + LOGLEVEL_INFO, // Informational. This is never actually used by + // libprotobuf. + LOGLEVEL_WARNING, // Warns about issues that, although not technically a + // problem now, could cause problems in the future. For + // example, a // warning will be printed when parsing a + // message that is near the message size limit. + LOGLEVEL_ERROR, // An error occurred which should never happen during + // normal use. + LOGLEVEL_FATAL, // An error occurred from which the library cannot + // recover. This usually indicates a programming error + // in the code which calls the library, especially when + // compiled in debug mode. + +#ifdef NDEBUG + LOGLEVEL_DFATAL = LOGLEVEL_ERROR +#else + LOGLEVEL_DFATAL = LOGLEVEL_FATAL +#endif +}; + +class StringPiece; +namespace util { +class Status; +} +class uint128; +namespace internal { + +class LogFinisher; + +class PROTOBUF_EXPORT LogMessage { + public: + LogMessage(LogLevel level, const char* filename, int line); + ~LogMessage(); + + LogMessage& operator<<(const std::string& value); + LogMessage& operator<<(const char* value); + LogMessage& operator<<(char value); + LogMessage& operator<<(int value); + LogMessage& operator<<(uint value); + LogMessage& operator<<(long value); + LogMessage& operator<<(unsigned long value); + LogMessage& operator<<(long long value); + LogMessage& operator<<(unsigned long long value); + LogMessage& operator<<(double value); + LogMessage& operator<<(void* value); + LogMessage& operator<<(const StringPiece& value); + LogMessage& operator<<(const util::Status& status); + LogMessage& operator<<(const uint128& value); + + private: + friend class LogFinisher; + void Finish(); + + LogLevel level_; + const char* filename_; + int line_; + std::string message_; +}; + +// Used to make the entire "LOG(BLAH) << etc." expression have a void return +// type and print a newline after each message. +class PROTOBUF_EXPORT LogFinisher { + public: + void operator=(LogMessage& other); +}; + +template +bool IsOk(T status) { return status.ok(); } +template<> +inline bool IsOk(bool status) { return status; } + +} // namespace internal + +// Undef everything in case we're being mixed with some other Google library +// which already defined them itself. Presumably all Google libraries will +// support the same syntax for these so it should not be a big deal if they +// end up using our definitions instead. +#undef GOOGLE_LOG +#undef GOOGLE_LOG_IF + +#undef GOOGLE_CHECK +#undef GOOGLE_CHECK_OK +#undef GOOGLE_CHECK_EQ +#undef GOOGLE_CHECK_NE +#undef GOOGLE_CHECK_LT +#undef GOOGLE_CHECK_LE +#undef GOOGLE_CHECK_GT +#undef GOOGLE_CHECK_GE +#undef GOOGLE_CHECK_NOTNULL + +#undef GOOGLE_DLOG +#undef GOOGLE_DCHECK +#undef GOOGLE_DCHECK_OK +#undef GOOGLE_DCHECK_EQ +#undef GOOGLE_DCHECK_NE +#undef GOOGLE_DCHECK_LT +#undef GOOGLE_DCHECK_LE +#undef GOOGLE_DCHECK_GT +#undef GOOGLE_DCHECK_GE + +#define GOOGLE_LOG(LEVEL) \ + ::google::protobuf::internal::LogFinisher() = \ + ::google::protobuf::internal::LogMessage( \ + ::google::protobuf::LOGLEVEL_##LEVEL, __FILE__, __LINE__) +#define GOOGLE_LOG_IF(LEVEL, CONDITION) \ + !(CONDITION) ? (void)0 : GOOGLE_LOG(LEVEL) + +#define GOOGLE_CHECK(EXPRESSION) \ + GOOGLE_LOG_IF(FATAL, !(EXPRESSION)) << "CHECK failed: " #EXPRESSION ": " +#define GOOGLE_CHECK_OK(A) GOOGLE_CHECK(::google::protobuf::internal::IsOk(A)) +#define GOOGLE_CHECK_EQ(A, B) GOOGLE_CHECK((A) == (B)) +#define GOOGLE_CHECK_NE(A, B) GOOGLE_CHECK((A) != (B)) +#define GOOGLE_CHECK_LT(A, B) GOOGLE_CHECK((A) < (B)) +#define GOOGLE_CHECK_LE(A, B) GOOGLE_CHECK((A) <= (B)) +#define GOOGLE_CHECK_GT(A, B) GOOGLE_CHECK((A) > (B)) +#define GOOGLE_CHECK_GE(A, B) GOOGLE_CHECK((A) >= (B)) + +namespace internal { +template +T* CheckNotNull(const char* /* file */, int /* line */, + const char* name, T* val) { + if (val == nullptr) { + GOOGLE_LOG(FATAL) << name; + } + return val; +} +} // namespace internal +#define GOOGLE_CHECK_NOTNULL(A) \ + ::google::protobuf::internal::CheckNotNull( \ + __FILE__, __LINE__, "'" #A "' must not be nullptr", (A)) + +#ifdef NDEBUG + +#define GOOGLE_DLOG(LEVEL) GOOGLE_LOG_IF(LEVEL, false) + +#define GOOGLE_DCHECK(EXPRESSION) while(false) GOOGLE_CHECK(EXPRESSION) +#define GOOGLE_DCHECK_OK(E) GOOGLE_DCHECK(::google::protobuf::internal::IsOk(E)) +#define GOOGLE_DCHECK_EQ(A, B) GOOGLE_DCHECK((A) == (B)) +#define GOOGLE_DCHECK_NE(A, B) GOOGLE_DCHECK((A) != (B)) +#define GOOGLE_DCHECK_LT(A, B) GOOGLE_DCHECK((A) < (B)) +#define GOOGLE_DCHECK_LE(A, B) GOOGLE_DCHECK((A) <= (B)) +#define GOOGLE_DCHECK_GT(A, B) GOOGLE_DCHECK((A) > (B)) +#define GOOGLE_DCHECK_GE(A, B) GOOGLE_DCHECK((A) >= (B)) + +#else // NDEBUG + +#define GOOGLE_DLOG GOOGLE_LOG + +#define GOOGLE_DCHECK GOOGLE_CHECK +#define GOOGLE_DCHECK_OK GOOGLE_CHECK_OK +#define GOOGLE_DCHECK_EQ GOOGLE_CHECK_EQ +#define GOOGLE_DCHECK_NE GOOGLE_CHECK_NE +#define GOOGLE_DCHECK_LT GOOGLE_CHECK_LT +#define GOOGLE_DCHECK_LE GOOGLE_CHECK_LE +#define GOOGLE_DCHECK_GT GOOGLE_CHECK_GT +#define GOOGLE_DCHECK_GE GOOGLE_CHECK_GE + +#endif // !NDEBUG + +typedef void LogHandler(LogLevel level, const char* filename, int line, + const std::string& message); + +// The protobuf library sometimes writes warning and error messages to +// stderr. These messages are primarily useful for developers, but may +// also help end users figure out a problem. If you would prefer that +// these messages be sent somewhere other than stderr, call SetLogHandler() +// to set your own handler. This returns the old handler. Set the handler +// to nullptr to ignore log messages (but see also LogSilencer, below). +// +// Obviously, SetLogHandler is not thread-safe. You should only call it +// at initialization time, and probably not from library code. If you +// simply want to suppress log messages temporarily (e.g. because you +// have some code that tends to trigger them frequently and you know +// the warnings are not important to you), use the LogSilencer class +// below. +PROTOBUF_EXPORT LogHandler* SetLogHandler(LogHandler* new_func); + +// Create a LogSilencer if you want to temporarily suppress all log +// messages. As long as any LogSilencer objects exist, non-fatal +// log messages will be discarded (the current LogHandler will *not* +// be called). Constructing a LogSilencer is thread-safe. You may +// accidentally suppress log messages occurring in another thread, but +// since messages are generally for debugging purposes only, this isn't +// a big deal. If you want to intercept log messages, use SetLogHandler(). +class PROTOBUF_EXPORT LogSilencer { + public: + LogSilencer(); + ~LogSilencer(); +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_LOGGING_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/macros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/macros.h new file mode 100644 index 0000000000000000000000000000000000000000..581790c6d72796fcf48170bd1aaa0ff2fee8e5b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/macros.h @@ -0,0 +1,125 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_MACROS_H__ +#define GOOGLE_PROTOBUF_MACROS_H__ + +#include + +namespace google { +namespace protobuf { + +#undef GOOGLE_DISALLOW_EVIL_CONSTRUCTORS +#define GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(TypeName) \ + TypeName(const TypeName&); \ + void operator=(const TypeName&) + +#undef GOOGLE_DISALLOW_IMPLICIT_CONSTRUCTORS +#define GOOGLE_DISALLOW_IMPLICIT_CONSTRUCTORS(TypeName) \ + TypeName(); \ + TypeName(const TypeName&); \ + void operator=(const TypeName&) + +// =================================================================== +// from google3/base/basictypes.h + +// The GOOGLE_ARRAYSIZE(arr) macro returns the # of elements in an array arr. +// The expression is a compile-time constant, and therefore can be +// used in defining new arrays, for example. +// +// GOOGLE_ARRAYSIZE catches a few type errors. If you see a compiler error +// +// "warning: division by zero in ..." +// +// when using GOOGLE_ARRAYSIZE, you are (wrongfully) giving it a pointer. +// You should only use GOOGLE_ARRAYSIZE on statically allocated arrays. +// +// The following comments are on the implementation details, and can +// be ignored by the users. +// +// ARRAYSIZE(arr) works by inspecting sizeof(arr) (the # of bytes in +// the array) and sizeof(*(arr)) (the # of bytes in one array +// element). If the former is divisible by the latter, perhaps arr is +// indeed an array, in which case the division result is the # of +// elements in the array. Otherwise, arr cannot possibly be an array, +// and we generate a compiler error to prevent the code from +// compiling. +// +// Since the size of bool is implementation-defined, we need to cast +// !(sizeof(a) & sizeof(*(a))) to size_t in order to ensure the final +// result has type size_t. +// +// This macro is not perfect as it wrongfully accepts certain +// pointers, namely where the pointer size is divisible by the pointee +// size. Since all our code has to go through a 32-bit compiler, +// where a pointer is 4 bytes, this means all pointers to a type whose +// size is 3 or greater than 4 will be (righteously) rejected. +// +// Kudos to Jorg Brown for this simple and elegant implementation. + +#undef GOOGLE_ARRAYSIZE +#define GOOGLE_ARRAYSIZE(a) \ + ((sizeof(a) / sizeof(*(a))) / \ + static_cast(!(sizeof(a) % sizeof(*(a))))) + +// The COMPILE_ASSERT macro can be used to verify that a compile time +// expression is true. For example, you could use it to verify the +// size of a static array: +// +// COMPILE_ASSERT(ARRAYSIZE(content_type_names) == CONTENT_NUM_TYPES, +// content_type_names_incorrect_size); +// +// or to make sure a struct is smaller than a certain size: +// +// COMPILE_ASSERT(sizeof(foo) < 128, foo_too_large); +// +// The second argument to the macro is the name of the variable. If +// the expression is false, most compilers will issue a warning/error +// containing the name of the variable. + +namespace internal { + +template +struct CompileAssert { +}; + +} // namespace internal + +#define GOOGLE_COMPILE_ASSERT(expr, msg) static_assert(expr, #msg) + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_MACROS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/map_util.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/map_util.h new file mode 100644 index 0000000000000000000000000000000000000000..17f6b90aa0ad3372f40b2fa5dfe23be2157f647e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/map_util.h @@ -0,0 +1,774 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2014 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// from google3/util/gtl/map_util.h +// Author: Anton Carver + +#ifndef GOOGLE_PROTOBUF_STUBS_MAP_UTIL_H__ +#define GOOGLE_PROTOBUF_STUBS_MAP_UTIL_H__ + +#include +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace internal { +// Local implementation of RemoveConst to avoid including base/type_traits.h. +template struct RemoveConst { typedef T type; }; +template struct RemoveConst : RemoveConst {}; +} // namespace internal + +// +// Find*() +// + +// Returns a const reference to the value associated with the given key if it +// exists. Crashes otherwise. +// +// This is intended as a replacement for operator[] as an rvalue (for reading) +// when the key is guaranteed to exist. +// +// operator[] for lookup is discouraged for several reasons: +// * It has a side-effect of inserting missing keys +// * It is not thread-safe (even when it is not inserting, it can still +// choose to resize the underlying storage) +// * It invalidates iterators (when it chooses to resize) +// * It default constructs a value object even if it doesn't need to +// +// This version assumes the key is printable, and includes it in the fatal log +// message. +template +const typename Collection::value_type::second_type& +FindOrDie(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found: " << key; + return it->second; +} + +// Same as above, but returns a non-const reference. +template +typename Collection::value_type::second_type& +FindOrDie(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found: " << key; + return it->second; +} + +// Same as FindOrDie above, but doesn't log the key on failure. +template +const typename Collection::value_type::second_type& +FindOrDieNoPrint(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found"; + return it->second; +} + +// Same as above, but returns a non-const reference. +template +typename Collection::value_type::second_type& +FindOrDieNoPrint(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found"; + return it->second; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. A common +// example: if you have a map with string values, and you pass a char* +// as the default "value," either use the returned value immediately +// or store it in a string (not string&). +// Details: http://go/findwithdefault +template +const typename Collection::value_type::second_type& +FindWithDefault(const Collection& collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return value; + } + return it->second; +} + +// Returns a pointer to the const value associated with the given key if it +// exists, or nullptr otherwise. +template +const typename Collection::value_type::second_type* +FindOrNull(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Same as above but returns a pointer to the non-const value. +template +typename Collection::value_type::second_type* +FindOrNull(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Returns the pointer value associated with the given key. If none is found, +// nullptr is returned. The function is designed to be used with a map of keys to +// pointers. +// +// This function does not distinguish between a missing key and a key mapped +// to nullptr. +template +typename Collection::value_type::second_type +FindPtrOrNull(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return typename Collection::value_type::second_type(); + } + return it->second; +} + +// Same as above, except takes non-const reference to collection. +// +// This function is needed for containers that propagate constness to the +// pointee, such as boost::ptr_map. +template +typename Collection::value_type::second_type +FindPtrOrNull(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + if (it == collection.end()) { + return typename Collection::value_type::second_type(); + } + return it->second; +} + +// Finds the pointer value associated with the given key in a map whose values +// are linked_ptrs. Returns nullptr if key is not found. +template +typename Collection::value_type::second_type::element_type* +FindLinkedPtrOrNull(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + // Since linked_ptr::get() is a const member returning a non const, + // we do not need a version of this function taking a non const collection. + return it->second.get(); +} + +// Same as above, but dies if the key is not found. +template +typename Collection::value_type::second_type::element_type& +FindLinkedPtrOrDie(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "key not found: " << key; + // Since linked_ptr::operator*() is a const member returning a non const, + // we do not need a version of this function taking a non const collection. + return *it->second; +} + +// Finds the value associated with the given key and copies it to *value (if not +// nullptr). Returns false if the key was not found, true otherwise. +template +bool FindCopy(const Collection& collection, + const Key& key, + Value* const value) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return false; + } + if (value) { + *value = it->second; + } + return true; +} + +// +// Contains*() +// + +// Returns true if and only if the given collection contains the given key. +template +bool ContainsKey(const Collection& collection, const Key& key) { + return collection.find(key) != collection.end(); +} + +// Returns true if and only if the given collection contains the given key-value +// pair. +template +bool ContainsKeyValuePair(const Collection& collection, + const Key& key, + const Value& value) { + typedef typename Collection::const_iterator const_iterator; + std::pair range = collection.equal_range(key); + for (const_iterator it = range.first; it != range.second; ++it) { + if (it->second == value) { + return true; + } + } + return false; +} + +// +// Insert*() +// + +// Inserts the given key-value pair into the collection. Returns true if and +// only if the key from the given pair didn't previously exist. Otherwise, the +// value in the map is replaced with the value from the given pair. +template +bool InsertOrUpdate(Collection* const collection, + const typename Collection::value_type& vt) { + std::pair ret = collection->insert(vt); + if (!ret.second) { + // update + ret.first->second = vt.second; + return false; + } + return true; +} + +// Same as above, except that the key and value are passed separately. +template +bool InsertOrUpdate(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertOrUpdate( + collection, typename Collection::value_type(key, value)); +} + +// Inserts/updates all the key-value pairs from the range defined by the +// iterators "first" and "last" into the given collection. +template +void InsertOrUpdateMany(Collection* const collection, + InputIterator first, InputIterator last) { + for (; first != last; ++first) { + InsertOrUpdate(collection, *first); + } +} + +// Change the value associated with a particular key in a map or hash_map +// of the form map which owns the objects pointed to by the +// value pointers. If there was an existing value for the key, it is deleted. +// True indicates an insert took place, false indicates an update + delete. +template +bool InsertAndDeleteExisting( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + std::pair ret = + collection->insert(typename Collection::value_type(key, value)); + if (!ret.second) { + delete ret.first->second; + ret.first->second = value; + return false; + } + return true; +} + +// Inserts the given key and value into the given collection if and only if the +// given key did NOT already exist in the collection. If the key previously +// existed in the collection, the value is not changed. Returns true if the +// key-value pair was inserted; returns false if the key was already present. +template +bool InsertIfNotPresent(Collection* const collection, + const typename Collection::value_type& vt) { + return collection->insert(vt).second; +} + +// Same as above except the key and value are passed separately. +template +bool InsertIfNotPresent( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertIfNotPresent( + collection, typename Collection::value_type(key, value)); +} + +// Same as above except dies if the key already exists in the collection. +template +void InsertOrDie(Collection* const collection, + const typename Collection::value_type& value) { + GOOGLE_CHECK(InsertIfNotPresent(collection, value)) + << "duplicate value: " << value; +} + +// Same as above except doesn't log the value on error. +template +void InsertOrDieNoPrint(Collection* const collection, + const typename Collection::value_type& value) { + GOOGLE_CHECK(InsertIfNotPresent(collection, value)) << "duplicate value."; +} + +// Inserts the key-value pair into the collection. Dies if key was already +// present. +template +void InsertOrDie(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& data) { + GOOGLE_CHECK(InsertIfNotPresent(collection, key, data)) + << "duplicate key: " << key; +} + +// Same as above except doesn't log the key on error. +template +void InsertOrDieNoPrint( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& data) { + GOOGLE_CHECK(InsertIfNotPresent(collection, key, data)) << "duplicate key."; +} + +// Inserts a new key and default-initialized value. Dies if the key was already +// present. Returns a reference to the value. Example usage: +// +// map m; +// SomeProto& proto = InsertKeyOrDie(&m, 3); +// proto.set_field("foo"); +template +typename Collection::value_type::second_type& InsertKeyOrDie( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename Collection::value_type value_type; + std::pair res = + collection->insert(value_type(key, typename value_type::second_type())); + GOOGLE_CHECK(res.second) << "duplicate key: " << key; + return res.first->second; +} + +// +// Lookup*() +// + +// Looks up a given key and value pair in a collection and inserts the key-value +// pair if it's not already present. Returns a reference to the value associated +// with the key. +template +typename Collection::value_type::second_type& +LookupOrInsert(Collection* const collection, + const typename Collection::value_type& vt) { + return collection->insert(vt).first->second; +} + +// Same as above except the key-value are passed separately. +template +typename Collection::value_type::second_type& +LookupOrInsert(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return LookupOrInsert( + collection, typename Collection::value_type(key, value)); +} + +// Counts the number of equivalent elements in the given "sequence", and stores +// the results in "count_map" with element as the key and count as the value. +// +// Example: +// vector v = {"a", "b", "c", "a", "b"}; +// map m; +// AddTokenCounts(v, 1, &m); +// assert(m["a"] == 2); +// assert(m["b"] == 2); +// assert(m["c"] == 1); +template +void AddTokenCounts( + const Sequence& sequence, + const typename Collection::value_type::second_type& increment, + Collection* const count_map) { + for (typename Sequence::const_iterator it = sequence.begin(); + it != sequence.end(); ++it) { + typename Collection::value_type::second_type& value = + LookupOrInsert(count_map, *it, + typename Collection::value_type::second_type()); + value += increment; + } +} + +// Returns a reference to the value associated with key. If not found, a value +// is default constructed on the heap and added to the map. +// +// This function is useful for containers of the form map, where +// inserting a new key, value pair involves constructing a new heap-allocated +// Value, and storing a pointer to that in the collection. +template +typename Collection::value_type::second_type& +LookupOrInsertNew(Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename std::iterator_traits< + typename Collection::value_type::second_type>::value_type Element; + std::pair ret = + collection->insert(typename Collection::value_type( + key, + static_cast(nullptr))); + if (ret.second) { + ret.first->second = new Element(); + } + return ret.first->second; +} + +// Same as above but constructs the value using the single-argument constructor +// and the given "arg". +template +typename Collection::value_type::second_type& +LookupOrInsertNew(Collection* const collection, + const typename Collection::value_type::first_type& key, + const Arg& arg) { + typedef typename std::iterator_traits< + typename Collection::value_type::second_type>::value_type Element; + std::pair ret = + collection->insert(typename Collection::value_type( + key, + static_cast(nullptr))); + if (ret.second) { + ret.first->second = new Element(arg); + } + return ret.first->second; +} + +// Lookup of linked/shared pointers is used in two scenarios: +// +// Use LookupOrInsertNewLinkedPtr if the container owns the elements. +// In this case it is fine working with the raw pointer as long as it is +// guaranteed that no other thread can delete/update an accessed element. +// A mutex will need to lock the container operation as well as the use +// of the returned elements. Finding an element may be performed using +// FindLinkedPtr*(). +// +// Use LookupOrInsertNewSharedPtr if the container does not own the elements +// for their whole lifetime. This is typically the case when a reader allows +// parallel updates to the container. In this case a Mutex only needs to lock +// container operations, but all element operations must be performed on the +// shared pointer. Finding an element must be performed using FindPtr*() and +// cannot be done with FindLinkedPtr*() even though it compiles. + +// Lookup a key in a map or hash_map whose values are linked_ptrs. If it is +// missing, set collection[key].reset(new Value::element_type) and return that. +// Value::element_type must be default constructable. +template +typename Collection::value_type::second_type::element_type* +LookupOrInsertNewLinkedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename Collection::value_type::second_type Value; + std::pair ret = + collection->insert(typename Collection::value_type(key, Value())); + if (ret.second) { + ret.first->second.reset(new typename Value::element_type); + } + return ret.first->second.get(); +} + +// A variant of LookupOrInsertNewLinkedPtr where the value is constructed using +// a single-parameter constructor. Note: the constructor argument is computed +// even if it will not be used, so only values cheap to compute should be passed +// here. On the other hand it does not matter how expensive the construction of +// the actual stored value is, as that only occurs if necessary. +template +typename Collection::value_type::second_type::element_type* +LookupOrInsertNewLinkedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const Arg& arg) { + typedef typename Collection::value_type::second_type Value; + std::pair ret = + collection->insert(typename Collection::value_type(key, Value())); + if (ret.second) { + ret.first->second.reset(new typename Value::element_type(arg)); + } + return ret.first->second.get(); +} + +// Lookup a key in a map or hash_map whose values are shared_ptrs. If it is +// missing, set collection[key].reset(new Value::element_type). Unlike +// LookupOrInsertNewLinkedPtr, this function returns the shared_ptr instead of +// the raw pointer. Value::element_type must be default constructable. +template +typename Collection::value_type::second_type& +LookupOrInsertNewSharedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename Collection::value_type::second_type SharedPtr; + typedef typename Collection::value_type::second_type::element_type Element; + std::pair ret = + collection->insert(typename Collection::value_type(key, SharedPtr())); + if (ret.second) { + ret.first->second.reset(new Element()); + } + return ret.first->second; +} + +// A variant of LookupOrInsertNewSharedPtr where the value is constructed using +// a single-parameter constructor. Note: the constructor argument is computed +// even if it will not be used, so only values cheap to compute should be passed +// here. On the other hand it does not matter how expensive the construction of +// the actual stored value is, as that only occurs if necessary. +template +typename Collection::value_type::second_type& +LookupOrInsertNewSharedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const Arg& arg) { + typedef typename Collection::value_type::second_type SharedPtr; + typedef typename Collection::value_type::second_type::element_type Element; + std::pair ret = + collection->insert(typename Collection::value_type(key, SharedPtr())); + if (ret.second) { + ret.first->second.reset(new Element(arg)); + } + return ret.first->second; +} + +// +// Misc Utility Functions +// + +// Updates the value associated with the given key. If the key was not already +// present, then the key-value pair are inserted and "previous" is unchanged. If +// the key was already present, the value is updated and "*previous" will +// contain a copy of the old value. +// +// InsertOrReturnExisting has complementary behavior that returns the +// address of an already existing value, rather than updating it. +template +bool UpdateReturnCopy(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value, + typename Collection::value_type::second_type* previous) { + std::pair ret = + collection->insert(typename Collection::value_type(key, value)); + if (!ret.second) { + // update + if (previous) { + *previous = ret.first->second; + } + ret.first->second = value; + return true; + } + return false; +} + +// Same as above except that the key and value are passed as a pair. +template +bool UpdateReturnCopy(Collection* const collection, + const typename Collection::value_type& vt, + typename Collection::value_type::second_type* previous) { + std::pair ret = collection->insert(vt); + if (!ret.second) { + // update + if (previous) { + *previous = ret.first->second; + } + ret.first->second = vt.second; + return true; + } + return false; +} + +// Tries to insert the given key-value pair into the collection. Returns nullptr if +// the insert succeeds. Otherwise, returns a pointer to the existing value. +// +// This complements UpdateReturnCopy in that it allows to update only after +// verifying the old value and still insert quickly without having to look up +// twice. Unlike UpdateReturnCopy this also does not come with the issue of an +// undefined previous* in case new data was inserted. +template +typename Collection::value_type::second_type* InsertOrReturnExisting( + Collection* const collection, const typename Collection::value_type& vt) { + std::pair ret = collection->insert(vt); + if (ret.second) { + return nullptr; // Inserted, no existing previous value. + } else { + return &ret.first->second; // Return address of already existing value. + } +} + +// Same as above, except for explicit key and data. +template +typename Collection::value_type::second_type* InsertOrReturnExisting( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& data) { + return InsertOrReturnExisting(collection, + typename Collection::value_type(key, data)); +} + +// Erases the collection item identified by the given key, and returns the value +// associated with that key. It is assumed that the value (i.e., the +// mapped_type) is a pointer. Returns nullptr if the key was not found in the +// collection. +// +// Examples: +// map my_map; +// +// One line cleanup: +// delete EraseKeyReturnValuePtr(&my_map, "abc"); +// +// Use returned value: +// std::unique_ptr value_ptr( +// EraseKeyReturnValuePtr(&my_map, "abc")); +// if (value_ptr.get()) +// value_ptr->DoSomething(); +// +template +typename Collection::value_type::second_type EraseKeyReturnValuePtr( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection->find(key); + if (it == collection->end()) { + return nullptr; + } + typename Collection::value_type::second_type v = it->second; + collection->erase(it); + return v; +} + +// Inserts all the keys from map_container into key_container, which must +// support insert(MapContainer::key_type). +// +// Note: any initial contents of the key_container are not cleared. +template +void InsertKeysFromMap(const MapContainer& map_container, + KeyContainer* key_container) { + GOOGLE_CHECK(key_container != nullptr); + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + key_container->insert(it->first); + } +} + +// Appends all the keys from map_container into key_container, which must +// support push_back(MapContainer::key_type). +// +// Note: any initial contents of the key_container are not cleared. +template +void AppendKeysFromMap(const MapContainer& map_container, + KeyContainer* key_container) { + GOOGLE_CHECK(key_container != nullptr); + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + key_container->push_back(it->first); + } +} + +// A more specialized overload of AppendKeysFromMap to optimize reallocations +// for the common case in which we're appending keys to a vector and hence can +// (and sometimes should) call reserve() first. +// +// (It would be possible to play SFINAE games to call reserve() for any +// container that supports it, but this seems to get us 99% of what we need +// without the complexity of a SFINAE-based solution.) +template +void AppendKeysFromMap(const MapContainer& map_container, + std::vector* key_container) { + GOOGLE_CHECK(key_container != nullptr); + // We now have the opportunity to call reserve(). Calling reserve() every + // time is a bad idea for some use cases: libstdc++'s implementation of + // vector<>::reserve() resizes the vector's backing store to exactly the + // given size (unless it's already at least that big). Because of this, + // the use case that involves appending a lot of small maps (total size + // N) one by one to a vector would be O(N^2). But never calling reserve() + // loses the opportunity to improve the use case of adding from a large + // map to an empty vector (this improves performance by up to 33%). A + // number of heuristics are possible; see the discussion in + // cl/34081696. Here we use the simplest one. + if (key_container->empty()) { + key_container->reserve(map_container.size()); + } + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + key_container->push_back(it->first); + } +} + +// Inserts all the values from map_container into value_container, which must +// support push_back(MapContainer::mapped_type). +// +// Note: any initial contents of the value_container are not cleared. +template +void AppendValuesFromMap(const MapContainer& map_container, + ValueContainer* value_container) { + GOOGLE_CHECK(value_container != nullptr); + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + value_container->push_back(it->second); + } +} + +// A more specialized overload of AppendValuesFromMap to optimize reallocations +// for the common case in which we're appending values to a vector and hence +// can (and sometimes should) call reserve() first. +// +// (It would be possible to play SFINAE games to call reserve() for any +// container that supports it, but this seems to get us 99% of what we need +// without the complexity of a SFINAE-based solution.) +template +void AppendValuesFromMap(const MapContainer& map_container, + std::vector* value_container) { + GOOGLE_CHECK(value_container != nullptr); + // See AppendKeysFromMap for why this is done. + if (value_container->empty()) { + value_container->reserve(map_container.size()); + } + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + value_container->push_back(it->second); + } +} + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_STUBS_MAP_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/mutex.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/mutex.h new file mode 100644 index 0000000000000000000000000000000000000000..2193d4493920a0d9aa9067605f933a0339c55bad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/mutex.h @@ -0,0 +1,191 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copyright (c) 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_MUTEX_H_ +#define GOOGLE_PROTOBUF_STUBS_MUTEX_H_ + +#include + +#ifdef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + +#include + +// GetMessage conflicts with GeneratedMessageReflection::GetMessage(). +#ifdef GetMessage +#undef GetMessage +#endif + +#endif + +#include + +// Define thread-safety annotations for use below, if we are building with +// Clang. +#if defined(__clang__) && !defined(SWIG) +#define GOOGLE_PROTOBUF_ACQUIRE(...) \ + __attribute__((acquire_capability(__VA_ARGS__))) +#define GOOGLE_PROTOBUF_RELEASE(...) \ + __attribute__((release_capability(__VA_ARGS__))) +#define GOOGLE_PROTOBUF_CAPABILITY(x) __attribute__((capability(x))) +#else +#define GOOGLE_PROTOBUF_ACQUIRE(...) +#define GOOGLE_PROTOBUF_RELEASE(...) +#define GOOGLE_PROTOBUF_CAPABILITY(x) +#endif + +#include + +// =================================================================== +// emulates google3/base/mutex.h +namespace google { +namespace protobuf { +namespace internal { + +#define GOOGLE_PROTOBUF_LINKER_INITIALIZED + +#ifdef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + +// This class is a lightweight replacement for std::mutex on Windows platforms. +// std::mutex does not work on Windows XP SP2 with the latest VC++ libraries, +// because it utilizes the Concurrency Runtime that is only supported on Windows +// XP SP3 and above. +class PROTOBUF_EXPORT CriticalSectionLock { + public: + CriticalSectionLock() { InitializeCriticalSection(&critical_section_); } + ~CriticalSectionLock() { DeleteCriticalSection(&critical_section_); } + void lock() { EnterCriticalSection(&critical_section_); } + void unlock() { LeaveCriticalSection(&critical_section_); } + + private: + CRITICAL_SECTION critical_section_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CriticalSectionLock); +}; + +#endif + +// Mutex is a natural type to wrap. As both google and other organization have +// specialized mutexes. gRPC also provides an injection mechanism for custom +// mutexes. +class GOOGLE_PROTOBUF_CAPABILITY("mutex") PROTOBUF_EXPORT WrappedMutex { + public: + WrappedMutex() = default; + void Lock() GOOGLE_PROTOBUF_ACQUIRE() { mu_.lock(); } + void Unlock() GOOGLE_PROTOBUF_RELEASE() { mu_.unlock(); } + // Crash if this Mutex is not held exclusively by this thread. + // May fail to crash when it should; will never crash when it should not. + void AssertHeld() const {} + + private: +#ifndef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + std::mutex mu_; +#else // ifndef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + CriticalSectionLock mu_; +#endif // #ifndef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP +}; + +using Mutex = WrappedMutex; + +// MutexLock(mu) acquires mu when constructed and releases it when destroyed. +class PROTOBUF_EXPORT MutexLock { + public: + explicit MutexLock(Mutex *mu) : mu_(mu) { this->mu_->Lock(); } + ~MutexLock() { this->mu_->Unlock(); } + private: + Mutex *const mu_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MutexLock); +}; + +// TODO(kenton): Implement these? Hard to implement portably. +typedef MutexLock ReaderMutexLock; +typedef MutexLock WriterMutexLock; + +// MutexLockMaybe is like MutexLock, but is a no-op when mu is nullptr. +class PROTOBUF_EXPORT MutexLockMaybe { + public: + explicit MutexLockMaybe(Mutex *mu) : + mu_(mu) { if (this->mu_ != nullptr) { this->mu_->Lock(); } } + ~MutexLockMaybe() { if (this->mu_ != nullptr) { this->mu_->Unlock(); } } + private: + Mutex *const mu_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MutexLockMaybe); +}; + +#if defined(GOOGLE_PROTOBUF_NO_THREADLOCAL) +template +class ThreadLocalStorage { + public: + ThreadLocalStorage() { + pthread_key_create(&key_, &ThreadLocalStorage::Delete); + } + ~ThreadLocalStorage() { + pthread_key_delete(key_); + } + T* Get() { + T* result = static_cast(pthread_getspecific(key_)); + if (result == nullptr) { + result = new T(); + pthread_setspecific(key_, result); + } + return result; + } + private: + static void Delete(void* value) { + delete static_cast(value); + } + pthread_key_t key_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ThreadLocalStorage); +}; +#endif + +} // namespace internal + +// We made these internal so that they would show up as such in the docs, +// but we don't want to stick "internal::" in front of them everywhere. +using internal::Mutex; +using internal::MutexLock; +using internal::ReaderMutexLock; +using internal::WriterMutexLock; +using internal::MutexLockMaybe; + +} // namespace protobuf +} // namespace google + +#undef GOOGLE_PROTOBUF_ACQUIRE +#undef GOOGLE_PROTOBUF_RELEASE + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_MUTEX_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/once.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/once.h new file mode 100644 index 0000000000000000000000000000000000000000..66ba5987a0d85f7cbf6dea96dee596f0ba0495fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/once.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_ONCE_H__ +#define GOOGLE_PROTOBUF_STUBS_ONCE_H__ + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace internal { + +using once_flag = std::once_flag; +template +void call_once(Args&&... args ) { + std::call_once(std::forward(args)...); +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_ONCE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/platform_macros.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/platform_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..f5d154fff83ed0d2e4ffc2113ee1746a5aa109c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/platform_macros.h @@ -0,0 +1,139 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2012 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PLATFORM_MACROS_H_ +#define GOOGLE_PROTOBUF_PLATFORM_MACROS_H_ + +#define GOOGLE_PROTOBUF_PLATFORM_ERROR \ +#error "Host platform was not detected as supported by protobuf" + +// Processor architecture detection. For more info on what's defined, see: +// http://msdn.microsoft.com/en-us/library/b0084kay.aspx +// http://www.agner.org/optimize/calling_conventions.pdf +// or with gcc, run: "echo | gcc -E -dM -" +#if defined(_M_X64) || defined(__x86_64__) +#define GOOGLE_PROTOBUF_ARCH_X64 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(_M_IX86) || defined(__i386__) +#define GOOGLE_PROTOBUF_ARCH_IA32 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(__QNX__) +#define GOOGLE_PROTOBUF_ARCH_ARM_QNX 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(_M_ARM) || defined(__ARMEL__) +#define GOOGLE_PROTOBUF_ARCH_ARM 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(_M_ARM64) +#define GOOGLE_PROTOBUF_ARCH_ARM 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(__aarch64__) +#define GOOGLE_PROTOBUF_ARCH_AARCH64 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(__mips__) +#if defined(__LP64__) +#define GOOGLE_PROTOBUF_ARCH_MIPS64 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#else +#define GOOGLE_PROTOBUF_ARCH_MIPS 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#endif +#elif defined(__pnacl__) +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(sparc) +#define GOOGLE_PROTOBUF_ARCH_SPARC 1 +#if defined(__sparc_v9__) || defined(__sparcv9) || defined(__arch64__) +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#else +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#endif +#elif defined(_POWER) || defined(__powerpc64__) || defined(__PPC64__) +#define GOOGLE_PROTOBUF_ARCH_POWER 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(__PPC__) +#define GOOGLE_PROTOBUF_ARCH_PPC 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(__GNUC__) +# if (((__GNUC__ == 4) && (__GNUC_MINOR__ >= 7)) || (__GNUC__ > 4)) +// We fallback to the generic Clang/GCC >= 4.7 implementation in atomicops.h +# elif defined(__clang__) +# if !__has_extension(c_atomic) +GOOGLE_PROTOBUF_PLATFORM_ERROR +# endif +// We fallback to the generic Clang/GCC >= 4.7 implementation in atomicops.h +# endif +# if __LP64__ +# define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +# else +# define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +# endif +#else +GOOGLE_PROTOBUF_PLATFORM_ERROR +#endif + +#if defined(__APPLE__) +#define GOOGLE_PROTOBUF_OS_APPLE +#include +#include +#if TARGET_OS_IPHONE +#define GOOGLE_PROTOBUF_OS_IPHONE +#endif +#elif defined(__EMSCRIPTEN__) +#define GOOGLE_PROTOBUF_OS_EMSCRIPTEN +#elif defined(__native_client__) +#define GOOGLE_PROTOBUF_OS_NACL +#elif defined(sun) +#define GOOGLE_PROTOBUF_OS_SOLARIS +#elif defined(_AIX) +#define GOOGLE_PROTOBUF_OS_AIX +#elif defined(__ANDROID__) +#define GOOGLE_PROTOBUF_OS_ANDROID +#endif + +#undef GOOGLE_PROTOBUF_PLATFORM_ERROR + +#if defined(GOOGLE_PROTOBUF_OS_ANDROID) || defined(GOOGLE_PROTOBUF_OS_IPHONE) || defined(__OpenBSD__) +// Android ndk does not support the __thread keyword very well yet. Here +// we use pthread_key_create()/pthread_getspecific()/... methods for +// TLS support on android. +// iOS and OpenBSD also do not support the __thread keyword. +#define GOOGLE_PROTOBUF_NO_THREADLOCAL +#endif + +#if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 1070 +// __thread keyword requires at least 10.7 +#define GOOGLE_PROTOBUF_NO_THREADLOCAL +#endif + +#endif // GOOGLE_PROTOBUF_PLATFORM_MACROS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/port.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/port.h new file mode 100644 index 0000000000000000000000000000000000000000..a46e6de0e64aa94ad2683f1742b8a56320fb86ad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/port.h @@ -0,0 +1,410 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_PORT_H_ +#define GOOGLE_PROTOBUF_STUBS_PORT_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#undef PROTOBUF_LITTLE_ENDIAN +#ifdef _WIN32 + // Assuming windows is always little-endian. + // TODO(xiaofeng): The PROTOBUF_LITTLE_ENDIAN is not only used for + // optimization but also for correctness. We should define an + // different macro to test the big-endian code path in coded_stream. + #if !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) + #define PROTOBUF_LITTLE_ENDIAN 1 + #endif + #if _MSC_VER >= 1300 && !defined(__INTEL_COMPILER) + // If MSVC has "/RTCc" set, it will complain about truncating casts at + // runtime. This file contains some intentional truncating casts. + #pragma runtime_checks("c", off) + #endif +#else + #include // __BYTE_ORDER + #if defined(__OpenBSD__) + #include + #endif + #if ((defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)) || \ + (defined(__BYTE_ORDER) && __BYTE_ORDER == __LITTLE_ENDIAN) || \ + (defined(BYTE_ORDER) && BYTE_ORDER == LITTLE_ENDIAN)) && \ + !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) + #define PROTOBUF_LITTLE_ENDIAN 1 + #endif +#endif + +// These #includes are for the byte swap functions declared later on. +#ifdef _MSC_VER +#include // NOLINT(build/include) +#include +#elif defined(__APPLE__) +#include +#elif defined(__GLIBC__) || defined(__BIONIC__) || defined(__CYGWIN__) +#include // IWYU pragma: export +#endif + +// Legacy: some users reference these (internal-only) macros even though we +// don't need them any more. +#if defined(_MSC_VER) && defined(PROTOBUF_USE_DLLS) + #ifdef LIBPROTOBUF_EXPORTS + #define LIBPROTOBUF_EXPORT __declspec(dllexport) + #else + #define LIBPROTOBUF_EXPORT __declspec(dllimport) + #endif + #ifdef LIBPROTOC_EXPORTS + #define LIBPROTOC_EXPORT __declspec(dllexport) + #else + #define LIBPROTOC_EXPORT __declspec(dllimport) + #endif +#else + #define LIBPROTOBUF_EXPORT + #define LIBPROTOC_EXPORT +#endif + +#define PROTOBUF_RUNTIME_DEPRECATED(message) PROTOBUF_DEPRECATED_MSG(message) +#define GOOGLE_PROTOBUF_RUNTIME_DEPRECATED(message) \ + PROTOBUF_DEPRECATED_MSG(message) + +// =================================================================== +// from google3/base/port.h + +#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ + (defined(_MSC_VER) && _MSC_VER >= 1900)) +// Define this to 1 if the code is compiled in C++11 mode; leave it +// undefined otherwise. Do NOT define it to 0 -- that causes +// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. +#define LANG_CXX11 1 +#else +#error "Protobuf requires at least C++11." +#endif + +namespace google { +namespace protobuf { + +using ConstStringParam = const std::string &; + +typedef unsigned int uint; + +typedef int8_t int8; +typedef int16_t int16; +typedef int32_t int32; +typedef int64_t int64; + +typedef uint8_t uint8; +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; + +static const int32 kint32max = 0x7FFFFFFF; +static const int32 kint32min = -kint32max - 1; +static const int64 kint64max = PROTOBUF_LONGLONG(0x7FFFFFFFFFFFFFFF); +static const int64 kint64min = -kint64max - 1; +static const uint32 kuint32max = 0xFFFFFFFFu; +static const uint64 kuint64max = PROTOBUF_ULONGLONG(0xFFFFFFFFFFFFFFFF); + +#if defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) ||\ + defined(MEMORY_SANITIZER) + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus +uint16_t __sanitizer_unaligned_load16(const void *p); +uint32_t __sanitizer_unaligned_load32(const void *p); +uint64_t __sanitizer_unaligned_load64(const void *p); +void __sanitizer_unaligned_store16(void *p, uint16_t v); +void __sanitizer_unaligned_store32(void *p, uint32_t v); +void __sanitizer_unaligned_store64(void *p, uint64_t v); +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +inline uint16 GOOGLE_UNALIGNED_LOAD16(const void *p) { + return __sanitizer_unaligned_load16(p); +} + +inline uint32 GOOGLE_UNALIGNED_LOAD32(const void *p) { + return __sanitizer_unaligned_load32(p); +} + +inline uint64 GOOGLE_UNALIGNED_LOAD64(const void *p) { + return __sanitizer_unaligned_load64(p); +} + +inline void GOOGLE_UNALIGNED_STORE16(void *p, uint16 v) { + __sanitizer_unaligned_store16(p, v); +} + +inline void GOOGLE_UNALIGNED_STORE32(void *p, uint32 v) { + __sanitizer_unaligned_store32(p, v); +} + +inline void GOOGLE_UNALIGNED_STORE64(void *p, uint64 v) { + __sanitizer_unaligned_store64(p, v); +} + +#elif defined(GOOGLE_PROTOBUF_USE_UNALIGNED) && GOOGLE_PROTOBUF_USE_UNALIGNED + +#define GOOGLE_UNALIGNED_LOAD16(_p) (*reinterpret_cast(_p)) +#define GOOGLE_UNALIGNED_LOAD32(_p) (*reinterpret_cast(_p)) +#define GOOGLE_UNALIGNED_LOAD64(_p) (*reinterpret_cast(_p)) + +#define GOOGLE_UNALIGNED_STORE16(_p, _val) (*reinterpret_cast(_p) = (_val)) +#define GOOGLE_UNALIGNED_STORE32(_p, _val) (*reinterpret_cast(_p) = (_val)) +#define GOOGLE_UNALIGNED_STORE64(_p, _val) (*reinterpret_cast(_p) = (_val)) + +#else +inline uint16 GOOGLE_UNALIGNED_LOAD16(const void *p) { + uint16 t; + memcpy(&t, p, sizeof t); + return t; +} + +inline uint32 GOOGLE_UNALIGNED_LOAD32(const void *p) { + uint32 t; + memcpy(&t, p, sizeof t); + return t; +} + +inline uint64 GOOGLE_UNALIGNED_LOAD64(const void *p) { + uint64 t; + memcpy(&t, p, sizeof t); + return t; +} + +inline void GOOGLE_UNALIGNED_STORE16(void *p, uint16 v) { + memcpy(p, &v, sizeof v); +} + +inline void GOOGLE_UNALIGNED_STORE32(void *p, uint32 v) { + memcpy(p, &v, sizeof v); +} + +inline void GOOGLE_UNALIGNED_STORE64(void *p, uint64 v) { + memcpy(p, &v, sizeof v); +} +#endif + +#if defined(GOOGLE_PROTOBUF_OS_NACL) \ + || (defined(__ANDROID__) && defined(__clang__) \ + && (__clang_major__ == 3 && __clang_minor__ == 8) \ + && (__clang_patchlevel__ < 275480)) +# define GOOGLE_PROTOBUF_USE_PORTABLE_LOG2 +#endif + +// The following guarantees declaration of the byte swap functions. +#ifdef _MSC_VER +#define bswap_16(x) _byteswap_ushort(x) +#define bswap_32(x) _byteswap_ulong(x) +#define bswap_64(x) _byteswap_uint64(x) + +#elif defined(__APPLE__) +// Mac OS X / Darwin features +#define bswap_16(x) OSSwapInt16(x) +#define bswap_32(x) OSSwapInt32(x) +#define bswap_64(x) OSSwapInt64(x) + +#elif !defined(__GLIBC__) && !defined(__BIONIC__) && !defined(__CYGWIN__) + +#ifndef bswap_16 +static inline uint16 bswap_16(uint16 x) { + return static_cast(((x & 0xFF) << 8) | ((x & 0xFF00) >> 8)); +} +#define bswap_16(x) bswap_16(x) +#endif + +#ifndef bswap_32 +static inline uint32 bswap_32(uint32 x) { + return (((x & 0xFF) << 24) | + ((x & 0xFF00) << 8) | + ((x & 0xFF0000) >> 8) | + ((x & 0xFF000000) >> 24)); +} +#define bswap_32(x) bswap_32(x) +#endif + +#ifndef bswap_64 +static inline uint64 bswap_64(uint64 x) { + return (((x & PROTOBUF_ULONGLONG(0xFF)) << 56) | + ((x & PROTOBUF_ULONGLONG(0xFF00)) << 40) | + ((x & PROTOBUF_ULONGLONG(0xFF0000)) << 24) | + ((x & PROTOBUF_ULONGLONG(0xFF000000)) << 8) | + ((x & PROTOBUF_ULONGLONG(0xFF00000000)) >> 8) | + ((x & PROTOBUF_ULONGLONG(0xFF0000000000)) >> 24) | + ((x & PROTOBUF_ULONGLONG(0xFF000000000000)) >> 40) | + ((x & PROTOBUF_ULONGLONG(0xFF00000000000000)) >> 56)); +} +#define bswap_64(x) bswap_64(x) +#endif + +#endif + +// =================================================================== +// from google3/util/bits/bits.h + +class Bits { + public: + static uint32 Log2FloorNonZero(uint32 n) { +#if defined(__GNUC__) + return 31 ^ static_cast(__builtin_clz(n)); +#elif defined(_MSC_VER) + unsigned long where; + _BitScanReverse(&where, n); + return where; +#else + return Log2FloorNonZero_Portable(n); +#endif + } + + static uint32 Log2FloorNonZero64(uint64 n) { + // Older versions of clang run into an instruction-selection failure when + // it encounters __builtin_clzll: + // https://bugs.chromium.org/p/nativeclient/issues/detail?id=4395 + // This includes arm-nacl-clang and clang in older Android NDK versions. + // To work around this, when we build with those we use the portable + // implementation instead. +#if defined(__GNUC__) && !defined(GOOGLE_PROTOBUF_USE_PORTABLE_LOG2) + return 63 ^ static_cast(__builtin_clzll(n)); +#elif defined(_MSC_VER) && defined(_M_X64) + unsigned long where; + _BitScanReverse64(&where, n); + return where; +#else + return Log2FloorNonZero64_Portable(n); +#endif + } + private: + static int Log2FloorNonZero_Portable(uint32 n) { + if (n == 0) + return -1; + int log = 0; + uint32 value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32 x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + assert(value == 1); + return log; + } + + static int Log2FloorNonZero64_Portable(uint64 n) { + const uint32 topbits = static_cast(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return static_cast(Log2FloorNonZero(static_cast(n))); + } else { + return 32 + static_cast(Log2FloorNonZero(topbits)); + } + } +}; + +// =================================================================== +// from google3/util/endian/endian.h +PROTOBUF_EXPORT uint32 ghtonl(uint32 x); + +class BigEndian { + public: +#ifdef PROTOBUF_LITTLE_ENDIAN + + static uint16 FromHost16(uint16 x) { return bswap_16(x); } + static uint16 ToHost16(uint16 x) { return bswap_16(x); } + + static uint32 FromHost32(uint32 x) { return bswap_32(x); } + static uint32 ToHost32(uint32 x) { return bswap_32(x); } + + static uint64 FromHost64(uint64 x) { return bswap_64(x); } + static uint64 ToHost64(uint64 x) { return bswap_64(x); } + + static bool IsLittleEndian() { return true; } + +#else + + static uint16 FromHost16(uint16 x) { return x; } + static uint16 ToHost16(uint16 x) { return x; } + + static uint32 FromHost32(uint32 x) { return x; } + static uint32 ToHost32(uint32 x) { return x; } + + static uint64 FromHost64(uint64 x) { return x; } + static uint64 ToHost64(uint64 x) { return x; } + + static bool IsLittleEndian() { return false; } + +#endif /* ENDIAN */ + + // Functions to do unaligned loads and stores in big-endian order. + static uint16 Load16(const void *p) { + return ToHost16(GOOGLE_UNALIGNED_LOAD16(p)); + } + + static void Store16(void *p, uint16 v) { + GOOGLE_UNALIGNED_STORE16(p, FromHost16(v)); + } + + static uint32 Load32(const void *p) { + return ToHost32(GOOGLE_UNALIGNED_LOAD32(p)); + } + + static void Store32(void *p, uint32 v) { + GOOGLE_UNALIGNED_STORE32(p, FromHost32(v)); + } + + static uint64 Load64(const void *p) { + return ToHost64(GOOGLE_UNALIGNED_LOAD64(p)); + } + + static void Store64(void *p, uint64 v) { + GOOGLE_UNALIGNED_STORE64(p, FromHost64(v)); + } +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_PORT_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/status.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/status.h new file mode 100644 index 0000000000000000000000000000000000000000..cf15c91a33ed90712c77d8ce62185caf82948cc4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/status.h @@ -0,0 +1,130 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#ifndef GOOGLE_PROTOBUF_STUBS_STATUS_H_ +#define GOOGLE_PROTOBUF_STUBS_STATUS_H_ + +#include +#include + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace util { +namespace error { +// These values must match error codes defined in google/rpc/code.proto. +enum Code { + OK = 0, + CANCELLED = 1, + UNKNOWN = 2, + INVALID_ARGUMENT = 3, + DEADLINE_EXCEEDED = 4, + NOT_FOUND = 5, + ALREADY_EXISTS = 6, + PERMISSION_DENIED = 7, + UNAUTHENTICATED = 16, + RESOURCE_EXHAUSTED = 8, + FAILED_PRECONDITION = 9, + ABORTED = 10, + OUT_OF_RANGE = 11, + UNIMPLEMENTED = 12, + INTERNAL = 13, + UNAVAILABLE = 14, + DATA_LOSS = 15, +}; +} // namespace error + +class PROTOBUF_EXPORT Status { + public: + // Creates a "successful" status. + Status(); + + // Create a status in the canonical error space with the specified + // code, and error message. If "code == 0", error_message is + // ignored and a Status object identical to Status::OK is + // constructed. + Status(error::Code error_code, StringPiece error_message); + Status(const Status&); + Status& operator=(const Status& x); + ~Status() {} + + // Some pre-defined Status objects + static const Status OK; // Identical to 0-arg constructor + static const Status CANCELLED; + static const Status UNKNOWN; + + // Accessor + bool ok() const { + return error_code_ == error::OK; + } + int error_code() const { + return error_code_; + } + error::Code code() const { + return error_code_; + } + StringPiece error_message() const { + return error_message_; + } + StringPiece message() const { + return error_message_; + } + + bool operator==(const Status& x) const; + bool operator!=(const Status& x) const { + return !operator==(x); + } + + // Return a combination of the error code name and message. + string ToString() const; + + private: + error::Code error_code_; + string error_message_; +}; + +// Prints a human-readable representation of 'x' to 'os'. +PROTOBUF_EXPORT std::ostream& operator<<(std::ostream& os, const Status& x); + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_STATUS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stl_util.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stl_util.h new file mode 100644 index 0000000000000000000000000000000000000000..89ca9b2fdfa10baff42682ed41ad66e0efed44f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stl_util.h @@ -0,0 +1,76 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// from google3/util/gtl/stl_util.h + +#ifndef GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__ +#define GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__ + +#include + +namespace google { +namespace protobuf { + +// Inside Google, this function implements a horrible, disgusting hack in which +// we reach into the string's private implementation and resize it without +// initializing the new bytes. In some cases doing this can significantly +// improve performance. However, since it's totally non-portable it has no +// place in open source code. Feel free to fill this function in with your +// own disgusting hack if you want the perf boost. +inline void STLStringResizeUninitialized(string* s, size_t new_size) { + s->resize(new_size); +} + +// Return a mutable char* pointing to a string's internal buffer, +// which may not be null-terminated. Writing through this pointer will +// modify the string. +// +// string_as_array(&str)[i] is valid for 0 <= i < str.size() until the +// next call to a string method that invalidates iterators. +// +// As of 2006-04, there is no standard-blessed way of getting a +// mutable reference to a string's internal buffer. However, issue 530 +// (http://www.open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#530) +// proposes this as the method. According to Matt Austern, this should +// already work on all current implementations. +inline char* string_as_array(string* str) { + // DO NOT USE const_cast(str->data())! See the unittest for why. + return str->empty() ? nullptr : &*str->begin(); +} + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stringpiece.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stringpiece.h new file mode 100644 index 0000000000000000000000000000000000000000..b1c17f2605f7511f072eb7c66b22c127b25830dc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stringpiece.h @@ -0,0 +1,494 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A StringPiece points to part or all of a string, Cord, double-quoted string +// literal, or other string-like object. A StringPiece does *not* own the +// string to which it points. A StringPiece is not null-terminated. +// +// You can use StringPiece as a function or method parameter. A StringPiece +// parameter can receive a double-quoted string literal argument, a "const +// char*" argument, a string argument, or a StringPiece argument with no data +// copying. Systematic use of StringPiece for arguments reduces data +// copies and strlen() calls. +// +// Prefer passing StringPieces by value: +// void MyFunction(StringPiece arg); +// If circumstances require, you may also pass by const reference: +// void MyFunction(const StringPiece& arg); // not preferred +// Both of these have the same lifetime semantics. Passing by value +// generates slightly smaller code. For more discussion, see the thread +// go/stringpiecebyvalue on c-users. +// +// StringPiece is also suitable for local variables if you know that +// the lifetime of the underlying object is longer than the lifetime +// of your StringPiece variable. +// +// Beware of binding a StringPiece to a temporary: +// StringPiece sp = obj.MethodReturningString(); // BAD: lifetime problem +// +// This code is okay: +// string str = obj.MethodReturningString(); // str owns its contents +// StringPiece sp(str); // GOOD, because str outlives sp +// +// StringPiece is sometimes a poor choice for a return value and usually a poor +// choice for a data member. If you do use a StringPiece this way, it is your +// responsibility to ensure that the object pointed to by the StringPiece +// outlives the StringPiece. +// +// A StringPiece may represent just part of a string; thus the name "Piece". +// For example, when splitting a string, vector is a natural data +// type for the output. For another example, a Cord is a non-contiguous, +// potentially very long string-like object. The Cord class has an interface +// that iteratively provides StringPiece objects that point to the +// successive pieces of a Cord object. +// +// A StringPiece is not null-terminated. If you write code that scans a +// StringPiece, you must check its length before reading any characters. +// Common idioms that work on null-terminated strings do not work on +// StringPiece objects. +// +// There are several ways to create a null StringPiece: +// StringPiece() +// StringPiece(nullptr) +// StringPiece(nullptr, 0) +// For all of the above, sp.data() == nullptr, sp.length() == 0, +// and sp.empty() == true. Also, if you create a StringPiece with +// a non-null pointer then sp.data() != nullptr. Once created, +// sp.data() will stay either nullptr or not-nullptr, except if you call +// sp.clear() or sp.set(). +// +// Thus, you can use StringPiece(nullptr) to signal an out-of-band value +// that is different from other StringPiece values. This is similar +// to the way that const char* p1 = nullptr; is different from +// const char* p2 = "";. +// +// There are many ways to create an empty StringPiece: +// StringPiece() +// StringPiece(nullptr) +// StringPiece(nullptr, 0) +// StringPiece("") +// StringPiece("", 0) +// StringPiece("abcdef", 0) +// StringPiece("abcdef"+6, 0) +// For all of the above, sp.length() will be 0 and sp.empty() will be true. +// For some empty StringPiece values, sp.data() will be nullptr. +// For some empty StringPiece values, sp.data() will not be nullptr. +// +// Be careful not to confuse: null StringPiece and empty StringPiece. +// The set of empty StringPieces properly includes the set of null StringPieces. +// That is, every null StringPiece is an empty StringPiece, +// but some non-null StringPieces are empty Stringpieces too. +// +// All empty StringPiece values compare equal to each other. +// Even a null StringPieces compares equal to a non-null empty StringPiece: +// StringPiece() == StringPiece("", 0) +// StringPiece(nullptr) == StringPiece("abc", 0) +// StringPiece(nullptr, 0) == StringPiece("abcdef"+6, 0) +// +// Look carefully at this example: +// StringPiece("") == nullptr +// True or false? TRUE, because StringPiece::operator== converts +// the right-hand side from nullptr to StringPiece(nullptr), +// and then compares two zero-length spans of characters. +// However, we are working to make this example produce a compile error. +// +// Suppose you want to write: +// bool TestWhat?(StringPiece sp) { return sp == nullptr; } // BAD +// Do not do that. Write one of these instead: +// bool TestNull(StringPiece sp) { return sp.data() == nullptr; } +// bool TestEmpty(StringPiece sp) { return sp.empty(); } +// The intent of TestWhat? is unclear. Did you mean TestNull or TestEmpty? +// Right now, TestWhat? behaves likes TestEmpty. +// We are working to make TestWhat? produce a compile error. +// TestNull is good to test for an out-of-band signal. +// TestEmpty is good to test for an empty StringPiece. +// +// Caveats (again): +// (1) The lifetime of the pointed-to string (or piece of a string) +// must be longer than the lifetime of the StringPiece. +// (2) There may or may not be a '\0' character after the end of +// StringPiece data. +// (3) A null StringPiece is empty. +// An empty StringPiece may or may not be a null StringPiece. + +#ifndef GOOGLE_PROTOBUF_STUBS_STRINGPIECE_H_ +#define GOOGLE_PROTOBUF_STUBS_STRINGPIECE_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace google { +namespace protobuf { +// StringPiece has *two* size types. +// StringPiece::size_type +// is unsigned +// is 32 bits in LP32, 64 bits in LP64, 64 bits in LLP64 +// no future changes intended +// stringpiece_ssize_type +// is signed +// is 32 bits in LP32, 64 bits in LP64, 64 bits in LLP64 +// future changes intended: http://go/64BitStringPiece +// +typedef std::string::difference_type stringpiece_ssize_type; + +// STRINGPIECE_CHECK_SIZE protects us from 32-bit overflows. +// TODO(mec): delete this after stringpiece_ssize_type goes 64 bit. +#if !defined(NDEBUG) +#define STRINGPIECE_CHECK_SIZE 1 +#elif defined(_FORTIFY_SOURCE) && _FORTIFY_SOURCE > 0 +#define STRINGPIECE_CHECK_SIZE 1 +#else +#define STRINGPIECE_CHECK_SIZE 0 +#endif + +class PROTOBUF_EXPORT StringPiece { + private: + const char* ptr_; + stringpiece_ssize_type length_; + + // Prevent overflow in debug mode or fortified mode. + // sizeof(stringpiece_ssize_type) may be smaller than sizeof(size_t). + static stringpiece_ssize_type CheckedSsizeTFromSizeT(size_t size) { +#if STRINGPIECE_CHECK_SIZE > 0 +#ifdef max +#undef max +#endif + if (size > static_cast( + std::numeric_limits::max())) { + // Some people grep for this message in logs + // so take care if you ever change it. + LogFatalSizeTooBig(size, "size_t to int conversion"); + } +#endif + return static_cast(size); + } + + // Out-of-line error path. + static void LogFatalSizeTooBig(size_t size, const char* details); + + public: + // We provide non-explicit singleton constructors so users can pass + // in a "const char*" or a "string" wherever a "StringPiece" is + // expected. + // + // Style guide exception granted: + // http://goto/style-guide-exception-20978288 + StringPiece() : ptr_(nullptr), length_(0) {} + + StringPiece(const char* str) // NOLINT(runtime/explicit) + : ptr_(str), length_(0) { + if (str != nullptr) { + length_ = CheckedSsizeTFromSizeT(strlen(str)); + } + } + + template + StringPiece( // NOLINT(runtime/explicit) + const std::basic_string, Allocator>& str) + : ptr_(str.data()), length_(0) { + length_ = CheckedSsizeTFromSizeT(str.size()); + } + + StringPiece(const char* offset, stringpiece_ssize_type len) + : ptr_(offset), length_(len) { + assert(len >= 0); + } + + // Substring of another StringPiece. + // pos must be non-negative and <= x.length(). + StringPiece(StringPiece x, stringpiece_ssize_type pos); + // Substring of another StringPiece. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + StringPiece(StringPiece x, + stringpiece_ssize_type pos, + stringpiece_ssize_type len); + + // data() may return a pointer to a buffer with embedded NULs, and the + // returned buffer may or may not be null terminated. Therefore it is + // typically a mistake to pass data() to a routine that expects a NUL + // terminated string. + const char* data() const { return ptr_; } + stringpiece_ssize_type size() const { return length_; } + stringpiece_ssize_type length() const { return length_; } + bool empty() const { return length_ == 0; } + + void clear() { + ptr_ = nullptr; + length_ = 0; + } + + void set(const char* data, stringpiece_ssize_type len) { + assert(len >= 0); + ptr_ = data; + length_ = len; + } + + void set(const char* str) { + ptr_ = str; + if (str != nullptr) + length_ = CheckedSsizeTFromSizeT(strlen(str)); + else + length_ = 0; + } + + void set(const void* data, stringpiece_ssize_type len) { + ptr_ = reinterpret_cast(data); + length_ = len; + } + + char operator[](stringpiece_ssize_type i) const { + assert(0 <= i); + assert(i < length_); + return ptr_[i]; + } + + void remove_prefix(stringpiece_ssize_type n) { + assert(length_ >= n); + ptr_ += n; + length_ -= n; + } + + void remove_suffix(stringpiece_ssize_type n) { + assert(length_ >= n); + length_ -= n; + } + + // returns {-1, 0, 1} + int compare(StringPiece x) const { + const stringpiece_ssize_type min_size = + length_ < x.length_ ? length_ : x.length_; + int r = memcmp(ptr_, x.ptr_, static_cast(min_size)); + if (r < 0) return -1; + if (r > 0) return 1; + if (length_ < x.length_) return -1; + if (length_ > x.length_) return 1; + return 0; + } + + std::string as_string() const { return ToString(); } + // We also define ToString() here, since many other string-like + // interfaces name the routine that converts to a C++ string + // "ToString", and it's confusing to have the method that does that + // for a StringPiece be called "as_string()". We also leave the + // "as_string()" method defined here for existing code. + std::string ToString() const { + if (ptr_ == nullptr) return ""; + return std::string(data(), static_cast(size())); + } + + explicit operator std::string() const { return ToString(); } + + void CopyToString(std::string* target) const; + void AppendToString(std::string* target) const; + + bool starts_with(StringPiece x) const { + return (length_ >= x.length_) && + (memcmp(ptr_, x.ptr_, static_cast(x.length_)) == 0); + } + + bool ends_with(StringPiece x) const { + return ((length_ >= x.length_) && + (memcmp(ptr_ + (length_-x.length_), x.ptr_, + static_cast(x.length_)) == 0)); + } + + // Checks whether StringPiece starts with x and if so advances the beginning + // of it to past the match. It's basically a shortcut for starts_with + // followed by remove_prefix. + bool Consume(StringPiece x); + // Like above but for the end of the string. + bool ConsumeFromEnd(StringPiece x); + + // standard STL container boilerplate + typedef char value_type; + typedef const char* pointer; + typedef const char& reference; + typedef const char& const_reference; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + static const size_type npos; + typedef const char* const_iterator; + typedef const char* iterator; + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(ptr_ + length_); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(ptr_); + } + stringpiece_ssize_type max_size() const { return length_; } + stringpiece_ssize_type capacity() const { return length_; } + + // cpplint.py emits a false positive [build/include_what_you_use] + stringpiece_ssize_type copy(char* buf, size_type n, size_type pos = 0) const; // NOLINT + + bool contains(StringPiece s) const; + + stringpiece_ssize_type find(StringPiece s, size_type pos = 0) const; + stringpiece_ssize_type find(char c, size_type pos = 0) const; + stringpiece_ssize_type rfind(StringPiece s, size_type pos = npos) const; + stringpiece_ssize_type rfind(char c, size_type pos = npos) const; + + stringpiece_ssize_type find_first_of(StringPiece s, size_type pos = 0) const; + stringpiece_ssize_type find_first_of(char c, size_type pos = 0) const { + return find(c, pos); + } + stringpiece_ssize_type find_first_not_of(StringPiece s, + size_type pos = 0) const; + stringpiece_ssize_type find_first_not_of(char c, size_type pos = 0) const; + stringpiece_ssize_type find_last_of(StringPiece s, + size_type pos = npos) const; + stringpiece_ssize_type find_last_of(char c, size_type pos = npos) const { + return rfind(c, pos); + } + stringpiece_ssize_type find_last_not_of(StringPiece s, + size_type pos = npos) const; + stringpiece_ssize_type find_last_not_of(char c, size_type pos = npos) const; + + StringPiece substr(size_type pos, size_type n = npos) const; +}; + +// This large function is defined inline so that in a fairly common case where +// one of the arguments is a literal, the compiler can elide a lot of the +// following comparisons. +inline bool operator==(StringPiece x, StringPiece y) { + stringpiece_ssize_type len = x.size(); + if (len != y.size()) { + return false; + } + + return x.data() == y.data() || len <= 0 || + memcmp(x.data(), y.data(), static_cast(len)) == 0; +} + +inline bool operator!=(StringPiece x, StringPiece y) { + return !(x == y); +} + +inline bool operator<(StringPiece x, StringPiece y) { + const stringpiece_ssize_type min_size = + x.size() < y.size() ? x.size() : y.size(); + const int r = memcmp(x.data(), y.data(), static_cast(min_size)); + return (r < 0) || (r == 0 && x.size() < y.size()); +} + +inline bool operator>(StringPiece x, StringPiece y) { + return y < x; +} + +inline bool operator<=(StringPiece x, StringPiece y) { + return !(x > y); +} + +inline bool operator>=(StringPiece x, StringPiece y) { + return !(x < y); +} + +// allow StringPiece to be logged +extern std::ostream& operator<<(std::ostream& o, StringPiece piece); + +namespace internal { +// StringPiece is not a POD and can not be used in an union (pre C++11). We +// need a POD version of it. +struct StringPiecePod { + // Create from a StringPiece. + static StringPiecePod CreateFromStringPiece(StringPiece str) { + StringPiecePod pod; + pod.data_ = str.data(); + pod.size_ = str.size(); + return pod; + } + + // Cast to StringPiece. + operator StringPiece() const { return StringPiece(data_, size_); } + + bool operator==(const char* value) const { + return StringPiece(data_, size_) == StringPiece(value); + } + + char operator[](stringpiece_ssize_type i) const { + assert(0 <= i); + assert(i < size_); + return data_[i]; + } + + const char* data() const { return data_; } + + stringpiece_ssize_type size() const { + return size_; + } + + std::string ToString() const { + return std::string(data_, static_cast(size_)); + } + + explicit operator std::string() const { return ToString(); } + + private: + const char* data_; + stringpiece_ssize_type size_; +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_START +template<> struct hash { + size_t operator()(const StringPiece& s) const { + size_t result = 0; + for (const char *str = s.data(), *end = str + s.size(); str < end; str++) { + result = 5 * result + static_cast(*str); + } + return result; + } +}; +GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_END + +#include + +#endif // STRINGS_STRINGPIECE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/strutil.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/strutil.h new file mode 100644 index 0000000000000000000000000000000000000000..c5fdd08e00c66772b4ba1054c149dca12fcc8e9a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/strutil.h @@ -0,0 +1,952 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// from google3/strings/strutil.h + +#ifndef GOOGLE_PROTOBUF_STUBS_STRUTIL_H__ +#define GOOGLE_PROTOBUF_STUBS_STRUTIL_H__ + +#include +#include +#include + +#include +#include +#include + +namespace google { +namespace protobuf { + +#if defined(_MSC_VER) && _MSC_VER < 1800 +#define strtoll _strtoi64 +#define strtoull _strtoui64 +#elif defined(__DECCXX) && defined(__osf__) +// HP C++ on Tru64 does not have strtoll, but strtol is already 64-bit. +#define strtoll strtol +#define strtoull strtoul +#endif + +// ---------------------------------------------------------------------- +// ascii_isalnum() +// Check if an ASCII character is alphanumeric. We can't use ctype's +// isalnum() because it is affected by locale. This function is applied +// to identifiers in the protocol buffer language, not to natural-language +// strings, so locale should not be taken into account. +// ascii_isdigit() +// Like above, but only accepts digits. +// ascii_isspace() +// Check if the character is a space character. +// ---------------------------------------------------------------------- + +inline bool ascii_isalnum(char c) { + return ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + ('0' <= c && c <= '9'); +} + +inline bool ascii_isdigit(char c) { + return ('0' <= c && c <= '9'); +} + +inline bool ascii_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || + c == '\r'; +} + +inline bool ascii_isupper(char c) { + return c >= 'A' && c <= 'Z'; +} + +inline bool ascii_islower(char c) { + return c >= 'a' && c <= 'z'; +} + +inline char ascii_toupper(char c) { + return ascii_islower(c) ? c - ('a' - 'A') : c; +} + +inline char ascii_tolower(char c) { + return ascii_isupper(c) ? c + ('a' - 'A') : c; +} + +inline int hex_digit_to_int(char c) { + /* Assume ASCII. */ + int x = static_cast(c); + if (x > '9') { + x += 9; + } + return x & 0xf; +} + +// ---------------------------------------------------------------------- +// HasPrefixString() +// Check if a string begins with a given prefix. +// StripPrefixString() +// Given a string and a putative prefix, returns the string minus the +// prefix string if the prefix matches, otherwise the original +// string. +// ---------------------------------------------------------------------- +inline bool HasPrefixString(StringPiece str, StringPiece prefix) { + return str.size() >= prefix.size() && + memcmp(str.data(), prefix.data(), prefix.size()) == 0; +} + +inline string StripPrefixString(const string& str, const string& prefix) { + if (HasPrefixString(str, prefix)) { + return str.substr(prefix.size()); + } else { + return str; + } +} + +// ---------------------------------------------------------------------- +// HasSuffixString() +// Return true if str ends in suffix. +// StripSuffixString() +// Given a string and a putative suffix, returns the string minus the +// suffix string if the suffix matches, otherwise the original +// string. +// ---------------------------------------------------------------------- +inline bool HasSuffixString(StringPiece str, StringPiece suffix) { + return str.size() >= suffix.size() && + memcmp(str.data() + str.size() - suffix.size(), suffix.data(), + suffix.size()) == 0; +} + +inline string StripSuffixString(const string& str, const string& suffix) { + if (HasSuffixString(str, suffix)) { + return str.substr(0, str.size() - suffix.size()); + } else { + return str; + } +} + +// ---------------------------------------------------------------------- +// ReplaceCharacters +// Replaces any occurrence of the character 'remove' (or the characters +// in 'remove') with the character 'replacewith'. +// Good for keeping html characters or protocol characters (\t) out +// of places where they might cause a problem. +// StripWhitespace +// Removes whitespaces from both ends of the given string. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void ReplaceCharacters(string* s, const char* remove, + char replacewith); + +PROTOBUF_EXPORT void StripWhitespace(string* s); + +// ---------------------------------------------------------------------- +// LowerString() +// UpperString() +// ToUpper() +// Convert the characters in "s" to lowercase or uppercase. ASCII-only: +// these functions intentionally ignore locale because they are applied to +// identifiers used in the Protocol Buffer language, not to natural-language +// strings. +// ---------------------------------------------------------------------- + +inline void LowerString(string * s) { + string::iterator end = s->end(); + for (string::iterator i = s->begin(); i != end; ++i) { + // tolower() changes based on locale. We don't want this! + if ('A' <= *i && *i <= 'Z') *i += 'a' - 'A'; + } +} + +inline void UpperString(string * s) { + string::iterator end = s->end(); + for (string::iterator i = s->begin(); i != end; ++i) { + // toupper() changes based on locale. We don't want this! + if ('a' <= *i && *i <= 'z') *i += 'A' - 'a'; + } +} + +inline void ToUpper(string* s) { UpperString(s); } + +inline string ToUpper(const string& s) { + string out = s; + UpperString(&out); + return out; +} + +// ---------------------------------------------------------------------- +// StringReplace() +// Give me a string and two patterns "old" and "new", and I replace +// the first instance of "old" in the string with "new", if it +// exists. RETURN a new string, regardless of whether the replacement +// happened or not. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT string StringReplace(const string& s, const string& oldsub, + const string& newsub, bool replace_all); + +// ---------------------------------------------------------------------- +// SplitStringUsing() +// Split a string using a character delimiter. Append the components +// to 'result'. If there are consecutive delimiters, this function skips +// over all of them. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void SplitStringUsing(StringPiece full, const char* delim, + std::vector* res); + +// Split a string using one or more byte delimiters, presented +// as a nul-terminated c string. Append the components to 'result'. +// If there are consecutive delimiters, this function will return +// corresponding empty strings. If you want to drop the empty +// strings, try SplitStringUsing(). +// +// If "full" is the empty string, yields an empty string as the only value. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void SplitStringAllowEmpty(StringPiece full, const char* delim, + std::vector* result); + +// ---------------------------------------------------------------------- +// Split() +// Split a string using a character delimiter. +// ---------------------------------------------------------------------- +inline std::vector Split(StringPiece full, const char* delim, + bool skip_empty = true) { + std::vector result; + if (skip_empty) { + SplitStringUsing(full, delim, &result); + } else { + SplitStringAllowEmpty(full, delim, &result); + } + return result; +} + +// ---------------------------------------------------------------------- +// JoinStrings() +// These methods concatenate a vector of strings into a C++ string, using +// the C-string "delim" as a separator between components. There are two +// flavors of the function, one flavor returns the concatenated string, +// another takes a pointer to the target string. In the latter case the +// target string is cleared and overwritten. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void JoinStrings(const std::vector& components, + const char* delim, string* result); + +inline string JoinStrings(const std::vector& components, + const char* delim) { + string result; + JoinStrings(components, delim, &result); + return result; +} + +// ---------------------------------------------------------------------- +// UnescapeCEscapeSequences() +// Copies "source" to "dest", rewriting C-style escape sequences +// -- '\n', '\r', '\\', '\ooo', etc -- to their ASCII +// equivalents. "dest" must be sufficiently large to hold all +// the characters in the rewritten string (i.e. at least as large +// as strlen(source) + 1 should be safe, since the replacements +// are always shorter than the original escaped sequences). It's +// safe for source and dest to be the same. RETURNS the length +// of dest. +// +// It allows hex sequences \xhh, or generally \xhhhhh with an +// arbitrary number of hex digits, but all of them together must +// specify a value of a single byte (e.g. \x0045 is equivalent +// to \x45, and \x1234 is erroneous). +// +// It also allows escape sequences of the form \uhhhh (exactly four +// hex digits, upper or lower case) or \Uhhhhhhhh (exactly eight +// hex digits, upper or lower case) to specify a Unicode code +// point. The dest array will contain the UTF8-encoded version of +// that code-point (e.g., if source contains \u2019, then dest will +// contain the three bytes 0xE2, 0x80, and 0x99). +// +// Errors: In the first form of the call, errors are reported with +// LOG(ERROR). The same is true for the second form of the call if +// the pointer to the string std::vector is nullptr; otherwise, error +// messages are stored in the std::vector. In either case, the effect on +// the dest array is not defined, but rest of the source will be +// processed. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT int UnescapeCEscapeSequences(const char* source, char* dest); +PROTOBUF_EXPORT int UnescapeCEscapeSequences(const char* source, char* dest, + std::vector* errors); + +// ---------------------------------------------------------------------- +// UnescapeCEscapeString() +// This does the same thing as UnescapeCEscapeSequences, but creates +// a new string. The caller does not need to worry about allocating +// a dest buffer. This should be used for non performance critical +// tasks such as printing debug messages. It is safe for src and dest +// to be the same. +// +// The second call stores its errors in a supplied string vector. +// If the string vector pointer is nullptr, it reports the errors with LOG(). +// +// In the first and second calls, the length of dest is returned. In the +// the third call, the new string is returned. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT int UnescapeCEscapeString(const string& src, string* dest); +PROTOBUF_EXPORT int UnescapeCEscapeString(const string& src, string* dest, + std::vector* errors); +PROTOBUF_EXPORT string UnescapeCEscapeString(const string& src); + +// ---------------------------------------------------------------------- +// CEscape() +// Escapes 'src' using C-style escape sequences and returns the resulting +// string. +// +// Escaped chars: \n, \r, \t, ", ', \, and !isprint(). +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string CEscape(const string& src); + +// ---------------------------------------------------------------------- +// CEscapeAndAppend() +// Escapes 'src' using C-style escape sequences, and appends the escaped +// string to 'dest'. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void CEscapeAndAppend(StringPiece src, string* dest); + +namespace strings { +// Like CEscape() but does not escape bytes with the upper bit set. +PROTOBUF_EXPORT string Utf8SafeCEscape(const string& src); + +// Like CEscape() but uses hex (\x) escapes instead of octals. +PROTOBUF_EXPORT string CHexEscape(const string& src); +} // namespace strings + +// ---------------------------------------------------------------------- +// strto32() +// strtou32() +// strto64() +// strtou64() +// Architecture-neutral plug compatible replacements for strtol() and +// strtoul(). Long's have different lengths on ILP-32 and LP-64 +// platforms, so using these is safer, from the point of view of +// overflow behavior, than using the standard libc functions. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT int32 strto32_adaptor(const char* nptr, char** endptr, + int base); +PROTOBUF_EXPORT uint32 strtou32_adaptor(const char* nptr, char** endptr, + int base); + +inline int32 strto32(const char *nptr, char **endptr, int base) { + if (sizeof(int32) == sizeof(long)) + return strtol(nptr, endptr, base); + else + return strto32_adaptor(nptr, endptr, base); +} + +inline uint32 strtou32(const char *nptr, char **endptr, int base) { + if (sizeof(uint32) == sizeof(unsigned long)) + return strtoul(nptr, endptr, base); + else + return strtou32_adaptor(nptr, endptr, base); +} + +// For now, long long is 64-bit on all the platforms we care about, so these +// functions can simply pass the call to strto[u]ll. +inline int64 strto64(const char *nptr, char **endptr, int base) { + GOOGLE_COMPILE_ASSERT(sizeof(int64) == sizeof(long long), + sizeof_int64_is_not_sizeof_long_long); + return strtoll(nptr, endptr, base); +} + +inline uint64 strtou64(const char *nptr, char **endptr, int base) { + GOOGLE_COMPILE_ASSERT(sizeof(uint64) == sizeof(unsigned long long), + sizeof_uint64_is_not_sizeof_long_long); + return strtoull(nptr, endptr, base); +} + +// ---------------------------------------------------------------------- +// safe_strtob() +// safe_strto32() +// safe_strtou32() +// safe_strto64() +// safe_strtou64() +// safe_strtof() +// safe_strtod() +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT bool safe_strtob(StringPiece str, bool* value); + +PROTOBUF_EXPORT bool safe_strto32(const string& str, int32* value); +PROTOBUF_EXPORT bool safe_strtou32(const string& str, uint32* value); +inline bool safe_strto32(const char* str, int32* value) { + return safe_strto32(string(str), value); +} +inline bool safe_strto32(StringPiece str, int32* value) { + return safe_strto32(str.ToString(), value); +} +inline bool safe_strtou32(const char* str, uint32* value) { + return safe_strtou32(string(str), value); +} +inline bool safe_strtou32(StringPiece str, uint32* value) { + return safe_strtou32(str.ToString(), value); +} + +PROTOBUF_EXPORT bool safe_strto64(const string& str, int64* value); +PROTOBUF_EXPORT bool safe_strtou64(const string& str, uint64* value); +inline bool safe_strto64(const char* str, int64* value) { + return safe_strto64(string(str), value); +} +inline bool safe_strto64(StringPiece str, int64* value) { + return safe_strto64(str.ToString(), value); +} +inline bool safe_strtou64(const char* str, uint64* value) { + return safe_strtou64(string(str), value); +} +inline bool safe_strtou64(StringPiece str, uint64* value) { + return safe_strtou64(str.ToString(), value); +} + +PROTOBUF_EXPORT bool safe_strtof(const char* str, float* value); +PROTOBUF_EXPORT bool safe_strtod(const char* str, double* value); +inline bool safe_strtof(const string& str, float* value) { + return safe_strtof(str.c_str(), value); +} +inline bool safe_strtod(const string& str, double* value) { + return safe_strtod(str.c_str(), value); +} +inline bool safe_strtof(StringPiece str, float* value) { + return safe_strtof(str.ToString(), value); +} +inline bool safe_strtod(StringPiece str, double* value) { + return safe_strtod(str.ToString(), value); +} + +// ---------------------------------------------------------------------- +// FastIntToBuffer() +// FastHexToBuffer() +// FastHex64ToBuffer() +// FastHex32ToBuffer() +// FastTimeToBuffer() +// These are intended for speed. FastIntToBuffer() assumes the +// integer is non-negative. FastHexToBuffer() puts output in +// hex rather than decimal. FastTimeToBuffer() puts the output +// into RFC822 format. +// +// FastHex64ToBuffer() puts a 64-bit unsigned value in hex-format, +// padded to exactly 16 bytes (plus one byte for '\0') +// +// FastHex32ToBuffer() puts a 32-bit unsigned value in hex-format, +// padded to exactly 8 bytes (plus one byte for '\0') +// +// All functions take the output buffer as an arg. +// They all return a pointer to the beginning of the output, +// which may not be the beginning of the input buffer. +// ---------------------------------------------------------------------- + +// Suggested buffer size for FastToBuffer functions. Also works with +// DoubleToBuffer() and FloatToBuffer(). +static const int kFastToBufferSize = 32; + +PROTOBUF_EXPORT char* FastInt32ToBuffer(int32 i, char* buffer); +PROTOBUF_EXPORT char* FastInt64ToBuffer(int64 i, char* buffer); +char* FastUInt32ToBuffer(uint32 i, char* buffer); // inline below +char* FastUInt64ToBuffer(uint64 i, char* buffer); // inline below +PROTOBUF_EXPORT char* FastHexToBuffer(int i, char* buffer); +PROTOBUF_EXPORT char* FastHex64ToBuffer(uint64 i, char* buffer); +PROTOBUF_EXPORT char* FastHex32ToBuffer(uint32 i, char* buffer); + +// at least 22 bytes long +inline char* FastIntToBuffer(int i, char* buffer) { + return (sizeof(i) == 4 ? + FastInt32ToBuffer(i, buffer) : FastInt64ToBuffer(i, buffer)); +} +inline char* FastUIntToBuffer(unsigned int i, char* buffer) { + return (sizeof(i) == 4 ? + FastUInt32ToBuffer(i, buffer) : FastUInt64ToBuffer(i, buffer)); +} +inline char* FastLongToBuffer(long i, char* buffer) { + return (sizeof(i) == 4 ? + FastInt32ToBuffer(i, buffer) : FastInt64ToBuffer(i, buffer)); +} +inline char* FastULongToBuffer(unsigned long i, char* buffer) { + return (sizeof(i) == 4 ? + FastUInt32ToBuffer(i, buffer) : FastUInt64ToBuffer(i, buffer)); +} + +// ---------------------------------------------------------------------- +// FastInt32ToBufferLeft() +// FastUInt32ToBufferLeft() +// FastInt64ToBufferLeft() +// FastUInt64ToBufferLeft() +// +// Like the Fast*ToBuffer() functions above, these are intended for speed. +// Unlike the Fast*ToBuffer() functions, however, these functions write +// their output to the beginning of the buffer (hence the name, as the +// output is left-aligned). The caller is responsible for ensuring that +// the buffer has enough space to hold the output. +// +// Returns a pointer to the end of the string (i.e. the null character +// terminating the string). +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT char* FastInt32ToBufferLeft(int32 i, char* buffer); +PROTOBUF_EXPORT char* FastUInt32ToBufferLeft(uint32 i, char* buffer); +PROTOBUF_EXPORT char* FastInt64ToBufferLeft(int64 i, char* buffer); +PROTOBUF_EXPORT char* FastUInt64ToBufferLeft(uint64 i, char* buffer); + +// Just define these in terms of the above. +inline char* FastUInt32ToBuffer(uint32 i, char* buffer) { + FastUInt32ToBufferLeft(i, buffer); + return buffer; +} +inline char* FastUInt64ToBuffer(uint64 i, char* buffer) { + FastUInt64ToBufferLeft(i, buffer); + return buffer; +} + +inline string SimpleBtoa(bool value) { + return value ? "true" : "false"; +} + +// ---------------------------------------------------------------------- +// SimpleItoa() +// Description: converts an integer to a string. +// +// Return value: string +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string SimpleItoa(int i); +PROTOBUF_EXPORT string SimpleItoa(unsigned int i); +PROTOBUF_EXPORT string SimpleItoa(long i); +PROTOBUF_EXPORT string SimpleItoa(unsigned long i); +PROTOBUF_EXPORT string SimpleItoa(long long i); +PROTOBUF_EXPORT string SimpleItoa(unsigned long long i); + +// ---------------------------------------------------------------------- +// SimpleDtoa() +// SimpleFtoa() +// DoubleToBuffer() +// FloatToBuffer() +// Description: converts a double or float to a string which, if +// passed to NoLocaleStrtod(), will produce the exact same original double +// (except in case of NaN; all NaNs are considered the same value). +// We try to keep the string short but it's not guaranteed to be as +// short as possible. +// +// DoubleToBuffer() and FloatToBuffer() write the text to the given +// buffer and return it. The buffer must be at least +// kDoubleToBufferSize bytes for doubles and kFloatToBufferSize +// bytes for floats. kFastToBufferSize is also guaranteed to be large +// enough to hold either. +// +// Return value: string +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string SimpleDtoa(double value); +PROTOBUF_EXPORT string SimpleFtoa(float value); + +PROTOBUF_EXPORT char* DoubleToBuffer(double i, char* buffer); +PROTOBUF_EXPORT char* FloatToBuffer(float i, char* buffer); + +// In practice, doubles should never need more than 24 bytes and floats +// should never need more than 14 (including null terminators), but we +// overestimate to be safe. +static const int kDoubleToBufferSize = 32; +static const int kFloatToBufferSize = 24; + +namespace strings { + +enum PadSpec { + NO_PAD = 1, + ZERO_PAD_2, + ZERO_PAD_3, + ZERO_PAD_4, + ZERO_PAD_5, + ZERO_PAD_6, + ZERO_PAD_7, + ZERO_PAD_8, + ZERO_PAD_9, + ZERO_PAD_10, + ZERO_PAD_11, + ZERO_PAD_12, + ZERO_PAD_13, + ZERO_PAD_14, + ZERO_PAD_15, + ZERO_PAD_16, +}; + +struct Hex { + uint64 value; + enum PadSpec spec; + template + explicit Hex(Int v, PadSpec s = NO_PAD) + : spec(s) { + // Prevent sign-extension by casting integers to + // their unsigned counterparts. +#ifdef LANG_CXX11 + static_assert( + sizeof(v) == 1 || sizeof(v) == 2 || sizeof(v) == 4 || sizeof(v) == 8, + "Unknown integer type"); +#endif + value = sizeof(v) == 1 ? static_cast(v) + : sizeof(v) == 2 ? static_cast(v) + : sizeof(v) == 4 ? static_cast(v) + : static_cast(v); + } +}; + +struct PROTOBUF_EXPORT AlphaNum { + const char *piece_data_; // move these to string_ref eventually + size_t piece_size_; // move these to string_ref eventually + + char digits[kFastToBufferSize]; + + // No bool ctor -- bools convert to an integral type. + // A bool ctor would also convert incoming pointers (bletch). + + AlphaNum(int i32) + : piece_data_(digits), + piece_size_(FastInt32ToBufferLeft(i32, digits) - &digits[0]) {} + AlphaNum(unsigned int u32) + : piece_data_(digits), + piece_size_(FastUInt32ToBufferLeft(u32, digits) - &digits[0]) {} + AlphaNum(long long i64) + : piece_data_(digits), + piece_size_(FastInt64ToBufferLeft(i64, digits) - &digits[0]) {} + AlphaNum(unsigned long long u64) + : piece_data_(digits), + piece_size_(FastUInt64ToBufferLeft(u64, digits) - &digits[0]) {} + + // Note: on some architectures, "long" is only 32 bits, not 64, but the + // performance hit of using FastInt64ToBufferLeft to handle 32-bit values + // is quite minor. + AlphaNum(long i64) + : piece_data_(digits), + piece_size_(FastInt64ToBufferLeft(i64, digits) - &digits[0]) {} + AlphaNum(unsigned long u64) + : piece_data_(digits), + piece_size_(FastUInt64ToBufferLeft(u64, digits) - &digits[0]) {} + + AlphaNum(float f) + : piece_data_(digits), piece_size_(strlen(FloatToBuffer(f, digits))) {} + AlphaNum(double f) + : piece_data_(digits), piece_size_(strlen(DoubleToBuffer(f, digits))) {} + + AlphaNum(Hex hex); + + AlphaNum(const char* c_str) + : piece_data_(c_str), piece_size_(strlen(c_str)) {} + // TODO: Add a string_ref constructor, eventually + // AlphaNum(const StringPiece &pc) : piece(pc) {} + + AlphaNum(const string& str) + : piece_data_(str.data()), piece_size_(str.size()) {} + + AlphaNum(StringPiece str) + : piece_data_(str.data()), piece_size_(str.size()) {} + + AlphaNum(internal::StringPiecePod str) + : piece_data_(str.data()), piece_size_(str.size()) {} + + size_t size() const { return piece_size_; } + const char *data() const { return piece_data_; } + + private: + // Use ":" not ':' + AlphaNum(char c); // NOLINT(runtime/explicit) + + // Disallow copy and assign. + AlphaNum(const AlphaNum&); + void operator=(const AlphaNum&); +}; + +} // namespace strings + +using strings::AlphaNum; + +// ---------------------------------------------------------------------- +// StrCat() +// This merges the given strings or numbers, with no delimiter. This +// is designed to be the fastest possible way to construct a string out +// of a mix of raw C strings, strings, bool values, +// and numeric values. +// +// Don't use this for user-visible strings. The localization process +// works poorly on strings built up out of fragments. +// +// For clarity and performance, don't use StrCat when appending to a +// string. In particular, avoid using any of these (anti-)patterns: +// str.append(StrCat(...) +// str += StrCat(...) +// str = StrCat(str, ...) +// where the last is the worse, with the potential to change a loop +// from a linear time operation with O(1) dynamic allocations into a +// quadratic time operation with O(n) dynamic allocations. StrAppend +// is a better choice than any of the above, subject to the restriction +// of StrAppend(&str, a, b, c, ...) that none of the a, b, c, ... may +// be a reference into str. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f, + const AlphaNum& g); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f, + const AlphaNum& g, const AlphaNum& h); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f, + const AlphaNum& g, const AlphaNum& h, + const AlphaNum& i); + +inline string StrCat(const AlphaNum& a) { return string(a.data(), a.size()); } + +// ---------------------------------------------------------------------- +// StrAppend() +// Same as above, but adds the output to the given string. +// WARNING: For speed, StrAppend does not try to check each of its input +// arguments to be sure that they are not a subset of the string being +// appended to. That is, while this will work: +// +// string s = "foo"; +// s += s; +// +// This will not (necessarily) work: +// +// string s = "foo"; +// StrAppend(&s, s); +// +// Note: while StrCat supports appending up to 9 arguments, StrAppend +// is currently limited to 4. That's rarely an issue except when +// automatically transforming StrCat to StrAppend, and can easily be +// worked around as consecutive calls to StrAppend are quite efficient. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a); +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a, + const AlphaNum& b); +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a, + const AlphaNum& b, const AlphaNum& c); +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a, + const AlphaNum& b, const AlphaNum& c, + const AlphaNum& d); + +// ---------------------------------------------------------------------- +// Join() +// These methods concatenate a range of components into a C++ string, using +// the C-string "delim" as a separator between components. +// ---------------------------------------------------------------------- +template +void Join(Iterator start, Iterator end, + const char* delim, string* result) { + for (Iterator it = start; it != end; ++it) { + if (it != start) { + result->append(delim); + } + StrAppend(result, *it); + } +} + +template +string Join(const Range& components, + const char* delim) { + string result; + Join(components.begin(), components.end(), delim, &result); + return result; +} + +// ---------------------------------------------------------------------- +// ToHex() +// Return a lower-case hex string representation of the given integer. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string ToHex(uint64 num); + +// ---------------------------------------------------------------------- +// GlobalReplaceSubstring() +// Replaces all instances of a substring in a string. Does nothing +// if 'substring' is empty. Returns the number of replacements. +// +// NOTE: The string pieces must not overlap s. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT int GlobalReplaceSubstring(const string& substring, + const string& replacement, + string* s); + +// ---------------------------------------------------------------------- +// Base64Unescape() +// Converts "src" which is encoded in Base64 to its binary equivalent and +// writes it to "dest". If src contains invalid characters, dest is cleared +// and the function returns false. Returns true on success. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT bool Base64Unescape(StringPiece src, string* dest); + +// ---------------------------------------------------------------------- +// WebSafeBase64Unescape() +// This is a variation of Base64Unescape which uses '-' instead of '+', and +// '_' instead of '/'. src is not null terminated, instead specify len. I +// recommend that slen +struct identity_ { + typedef T type; +}; + +// integral_constant, defined in tr1, is a wrapper for an integer +// value. We don't really need this generality; we could get away +// with hardcoding the integer type to bool. We use the fully +// general integer_constant for compatibility with tr1. + +template +struct integral_constant { + static const T value = v; + typedef T value_type; + typedef integral_constant type; +}; + +template const T integral_constant::value; + + +// Abbreviations: true_type and false_type are structs that represent boolean +// true and false values. Also define the boost::mpl versions of those names, +// true_ and false_. +typedef integral_constant true_type; +typedef integral_constant false_type; +typedef true_type true_; +typedef false_type false_; + +// if_ is a templatized conditional statement. +// if_ is a compile time evaluation of cond. +// if_<>::type contains A if cond is true, B otherwise. +template +struct if_{ + typedef A type; +}; + +template +struct if_ { + typedef B type; +}; + + +// type_equals_ is a template type comparator, similar to Loki IsSameType. +// type_equals_::value is true iff "A" is the same type as "B". +// +// New code should prefer base::is_same, defined in base/type_traits.h. +// It is functionally identical, but is_same is the standard spelling. +template +struct type_equals_ : public false_ { +}; + +template +struct type_equals_ : public true_ { +}; + +// and_ is a template && operator. +// and_::value evaluates "A::value && B::value". +template +struct and_ : public integral_constant { +}; + +// or_ is a template || operator. +// or_::value evaluates "A::value || B::value". +template +struct or_ : public integral_constant { +}; + + +} // namespace internal +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_TEMPLATE_UTIL_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/text_format.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/text_format.h new file mode 100644 index 0000000000000000000000000000000000000000..43cb8041792469180262557e1c14666d9482cdac --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/text_format.h @@ -0,0 +1,651 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jschorr@google.com (Joseph Schorr) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Utilities for printing and parsing protocol messages in a human-readable, +// text-based format. + +#ifndef GOOGLE_PROTOBUF_TEXT_FORMAT_H__ +#define GOOGLE_PROTOBUF_TEXT_FORMAT_H__ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#ifdef SWIG +#error "You cannot SWIG proto headers" +#endif + +namespace google { +namespace protobuf { + +namespace io { +class ErrorCollector; // tokenizer.h +} + +// This class implements protocol buffer text format. Printing and parsing +// protocol messages in text format is useful for debugging and human editing +// of messages. +// +// This class is really a namespace that contains only static methods. +class PROTOBUF_EXPORT TextFormat { + public: + // Outputs a textual representation of the given message to the given + // output stream. Returns false if printing fails. + static bool Print(const Message& message, io::ZeroCopyOutputStream* output); + + // Print the fields in an UnknownFieldSet. They are printed by tag number + // only. Embedded messages are heuristically identified by attempting to + // parse them. Returns false if printing fails. + static bool PrintUnknownFields(const UnknownFieldSet& unknown_fields, + io::ZeroCopyOutputStream* output); + + // Like Print(), but outputs directly to a string. + // Note: output will be cleared prior to printing, and will be left empty + // even if printing fails. Returns false if printing fails. + static bool PrintToString(const Message& message, std::string* output); + + // Like PrintUnknownFields(), but outputs directly to a string. Returns + // false if printing fails. + static bool PrintUnknownFieldsToString(const UnknownFieldSet& unknown_fields, + std::string* output); + + // Outputs a textual representation of the value of the field supplied on + // the message supplied. For non-repeated fields, an index of -1 must + // be supplied. Note that this method will print the default value for a + // field if it is not set. + static void PrintFieldValueToString(const Message& message, + const FieldDescriptor* field, int index, + std::string* output); + + class PROTOBUF_EXPORT BaseTextGenerator { + public: + virtual ~BaseTextGenerator(); + + virtual void Indent() {} + virtual void Outdent() {} + // Returns the current indentation size in characters. + virtual size_t GetCurrentIndentationSize() const { return 0; } + + // Print text to the output stream. + virtual void Print(const char* text, size_t size) = 0; + + void PrintString(const std::string& str) { Print(str.data(), str.size()); } + + template + void PrintLiteral(const char (&text)[n]) { + Print(text, n - 1); // n includes the terminating zero character. + } + }; + + // The default printer that converts scalar values from fields into their + // string representation. + // You can derive from this FastFieldValuePrinter if you want to have fields + // to be printed in a different way and register it at the Printer. + class PROTOBUF_EXPORT FastFieldValuePrinter { + public: + FastFieldValuePrinter(); + virtual ~FastFieldValuePrinter(); + virtual void PrintBool(bool val, BaseTextGenerator* generator) const; + virtual void PrintInt32(int32 val, BaseTextGenerator* generator) const; + virtual void PrintUInt32(uint32 val, BaseTextGenerator* generator) const; + virtual void PrintInt64(int64 val, BaseTextGenerator* generator) const; + virtual void PrintUInt64(uint64 val, BaseTextGenerator* generator) const; + virtual void PrintFloat(float val, BaseTextGenerator* generator) const; + virtual void PrintDouble(double val, BaseTextGenerator* generator) const; + virtual void PrintString(const std::string& val, + BaseTextGenerator* generator) const; + virtual void PrintBytes(const std::string& val, + BaseTextGenerator* generator) const; + virtual void PrintEnum(int32 val, const std::string& name, + BaseTextGenerator* generator) const; + virtual void PrintFieldName(const Message& message, int field_index, + int field_count, const Reflection* reflection, + const FieldDescriptor* field, + BaseTextGenerator* generator) const; + virtual void PrintFieldName(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field, + BaseTextGenerator* generator) const; + virtual void PrintMessageStart(const Message& message, int field_index, + int field_count, bool single_line_mode, + BaseTextGenerator* generator) const; + // Allows to override the logic on how to print the content of a message. + // Return false to use the default printing logic. Note that it is legal for + // this function to print something and then return false to use the default + // content printing (although at that point it would behave similarly to + // PrintMessageStart). + virtual bool PrintMessageContent(const Message& message, int field_index, + int field_count, bool single_line_mode, + BaseTextGenerator* generator) const; + virtual void PrintMessageEnd(const Message& message, int field_index, + int field_count, bool single_line_mode, + BaseTextGenerator* generator) const; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FastFieldValuePrinter); + }; + + // Deprecated: please use FastFieldValuePrinter instead. + class PROTOBUF_EXPORT FieldValuePrinter { + public: + FieldValuePrinter(); + virtual ~FieldValuePrinter(); + virtual std::string PrintBool(bool val) const; + virtual std::string PrintInt32(int32 val) const; + virtual std::string PrintUInt32(uint32 val) const; + virtual std::string PrintInt64(int64 val) const; + virtual std::string PrintUInt64(uint64 val) const; + virtual std::string PrintFloat(float val) const; + virtual std::string PrintDouble(double val) const; + virtual std::string PrintString(const std::string& val) const; + virtual std::string PrintBytes(const std::string& val) const; + virtual std::string PrintEnum(int32 val, const std::string& name) const; + virtual std::string PrintFieldName(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) const; + virtual std::string PrintMessageStart(const Message& message, + int field_index, int field_count, + bool single_line_mode) const; + virtual std::string PrintMessageEnd(const Message& message, int field_index, + int field_count, + bool single_line_mode) const; + + private: + FastFieldValuePrinter delegate_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FieldValuePrinter); + }; + + class PROTOBUF_EXPORT MessagePrinter { + public: + MessagePrinter() {} + virtual ~MessagePrinter() {} + virtual void Print(const Message& message, bool single_line_mode, + BaseTextGenerator* generator) const = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MessagePrinter); + }; + + // Interface that Printers or Parsers can use to find extensions, or types + // referenced in Any messages. + class PROTOBUF_EXPORT Finder { + public: + virtual ~Finder(); + + // Try to find an extension of *message by fully-qualified field + // name. Returns NULL if no extension is known for this name or number. + // The base implementation uses the extensions already known by the message. + virtual const FieldDescriptor* FindExtension(Message* message, + const std::string& name) const; + + // Similar to FindExtension, but uses a Descriptor and the extension number + // instead of using a Message and the name when doing the look up. + virtual const FieldDescriptor* FindExtensionByNumber( + const Descriptor* descriptor, int number) const; + + // Find the message type for an Any proto. + // Returns NULL if no message is known for this name. + // The base implementation only accepts prefixes of type.googleprod.com/ or + // type.googleapis.com/, and searches the DescriptorPool of the parent + // message. + virtual const Descriptor* FindAnyType(const Message& message, + const std::string& prefix, + const std::string& name) const; + + // Find the message factory for the given extension field. This can be used + // to generalize the Parser to add extension fields to a message in the same + // way as the "input" message for the Parser. + virtual MessageFactory* FindExtensionFactory( + const FieldDescriptor* field) const; + }; + + // Class for those users which require more fine-grained control over how + // a protobuffer message is printed out. + class PROTOBUF_EXPORT Printer { + public: + Printer(); + + // Like TextFormat::Print + bool Print(const Message& message, io::ZeroCopyOutputStream* output) const; + // Like TextFormat::PrintUnknownFields + bool PrintUnknownFields(const UnknownFieldSet& unknown_fields, + io::ZeroCopyOutputStream* output) const; + // Like TextFormat::PrintToString + bool PrintToString(const Message& message, std::string* output) const; + // Like TextFormat::PrintUnknownFieldsToString + bool PrintUnknownFieldsToString(const UnknownFieldSet& unknown_fields, + std::string* output) const; + // Like TextFormat::PrintFieldValueToString + void PrintFieldValueToString(const Message& message, + const FieldDescriptor* field, int index, + std::string* output) const; + + // Adjust the initial indent level of all output. Each indent level is + // equal to two spaces. + void SetInitialIndentLevel(int indent_level) { + initial_indent_level_ = indent_level; + } + + // If printing in single line mode, then the entire message will be output + // on a single line with no line breaks. + void SetSingleLineMode(bool single_line_mode) { + single_line_mode_ = single_line_mode; + } + + bool IsInSingleLineMode() const { return single_line_mode_; } + + // If use_field_number is true, uses field number instead of field name. + void SetUseFieldNumber(bool use_field_number) { + use_field_number_ = use_field_number; + } + + // Set true to print repeated primitives in a format like: + // field_name: [1, 2, 3, 4] + // instead of printing each value on its own line. Short format applies + // only to primitive values -- i.e. everything except strings and + // sub-messages/groups. + void SetUseShortRepeatedPrimitives(bool use_short_repeated_primitives) { + use_short_repeated_primitives_ = use_short_repeated_primitives; + } + + // Set true to output UTF-8 instead of ASCII. The only difference + // is that bytes >= 0x80 in string fields will not be escaped, + // because they are assumed to be part of UTF-8 multi-byte + // sequences. This will change the default FastFieldValuePrinter. + void SetUseUtf8StringEscaping(bool as_utf8); + + // Set the default FastFieldValuePrinter that is used for all fields that + // don't have a field-specific printer registered. + // Takes ownership of the printer. + void SetDefaultFieldValuePrinter(const FastFieldValuePrinter* printer); + + PROTOBUF_DEPRECATED_MSG("Please use FastFieldValuePrinter") + void SetDefaultFieldValuePrinter(const FieldValuePrinter* printer); + + // Sets whether we want to hide unknown fields or not. + // Usually unknown fields are printed in a generic way that includes the + // tag number of the field instead of field name. However, sometimes it + // is useful to be able to print the message without unknown fields (e.g. + // for the python protobuf version to maintain consistency between its pure + // python and c++ implementations). + void SetHideUnknownFields(bool hide) { hide_unknown_fields_ = hide; } + + // If print_message_fields_in_index_order is true, fields of a proto message + // will be printed using the order defined in source code instead of the + // field number, extensions will be printed at the end of the message + // and their relative order is determined by the extension number. + // By default, use the field number order. + void SetPrintMessageFieldsInIndexOrder( + bool print_message_fields_in_index_order) { + print_message_fields_in_index_order_ = + print_message_fields_in_index_order; + } + + // If expand==true, expand google.protobuf.Any payloads. The output + // will be of form + // [type_url] { } + // + // If expand==false, print Any using the default printer. The output will + // look like + // type_url: "" value: "serialized_content" + void SetExpandAny(bool expand) { expand_any_ = expand; } + + // Set how parser finds message for Any payloads. + void SetFinder(const Finder* finder) { finder_ = finder; } + + // If non-zero, we truncate all string fields that are longer than + // this threshold. This is useful when the proto message has very long + // strings, e.g., dump of encoded image file. + // + // NOTE(hfgong): Setting a non-zero value breaks round-trip safe + // property of TextFormat::Printer. That is, from the printed message, we + // cannot fully recover the original string field any more. + void SetTruncateStringFieldLongerThan( + const int64 truncate_string_field_longer_than) { + truncate_string_field_longer_than_ = truncate_string_field_longer_than; + } + + // Register a custom field-specific FastFieldValuePrinter for fields + // with a particular FieldDescriptor. + // Returns "true" if the registration succeeded, or "false", if there is + // already a printer for that FieldDescriptor. + // Takes ownership of the printer on successful registration. + bool RegisterFieldValuePrinter(const FieldDescriptor* field, + const FastFieldValuePrinter* printer); + + PROTOBUF_DEPRECATED_MSG("Please use FastFieldValuePrinter") + bool RegisterFieldValuePrinter(const FieldDescriptor* field, + const FieldValuePrinter* printer); + + // Register a custom message-specific MessagePrinter for messages with a + // particular Descriptor. + // Returns "true" if the registration succeeded, or "false" if there is + // already a printer for that Descriptor. + bool RegisterMessagePrinter(const Descriptor* descriptor, + const MessagePrinter* printer); + + private: + // Forward declaration of an internal class used to print the text + // output to the OutputStream (see text_format.cc for implementation). + class TextGenerator; + + static const char* const kDoNotParse; + + // Internal Print method, used for writing to the OutputStream via + // the TextGenerator class. + void Print(const Message& message, TextGenerator* generator) const; + + // Print a single field. + void PrintField(const Message& message, const Reflection* reflection, + const FieldDescriptor* field, + TextGenerator* generator) const; + + // Print a repeated primitive field in short form. + void PrintShortRepeatedField(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field, + TextGenerator* generator) const; + + // Print the name of a field -- i.e. everything that comes before the + // ':' for a single name/value pair. + void PrintFieldName(const Message& message, int field_index, + int field_count, const Reflection* reflection, + const FieldDescriptor* field, + TextGenerator* generator) const; + + // Outputs a textual representation of the value of the field supplied on + // the message supplied or the default value if not set. + void PrintFieldValue(const Message& message, const Reflection* reflection, + const FieldDescriptor* field, int index, + TextGenerator* generator) const; + + // Print the fields in an UnknownFieldSet. They are printed by tag number + // only. Embedded messages are heuristically identified by attempting to + // parse them (subject to the recursion budget). + void PrintUnknownFields(const UnknownFieldSet& unknown_fields, + TextGenerator* generator, + int recursion_budget) const; + + bool PrintAny(const Message& message, TextGenerator* generator) const; + + const FastFieldValuePrinter* GetFieldPrinter( + const FieldDescriptor* field) const { + auto it = custom_printers_.find(field); + return it == custom_printers_.end() ? default_field_value_printer_.get() + : it->second.get(); + } + + int initial_indent_level_; + bool single_line_mode_; + bool use_field_number_; + bool use_short_repeated_primitives_; + bool hide_unknown_fields_; + bool print_message_fields_in_index_order_; + bool expand_any_; + int64 truncate_string_field_longer_than_; + + std::unique_ptr default_field_value_printer_; + typedef std::map> + CustomPrinterMap; + CustomPrinterMap custom_printers_; + + typedef std::map> + CustomMessagePrinterMap; + CustomMessagePrinterMap custom_message_printers_; + + const Finder* finder_; + }; + + // Parses a text-format protocol message from the given input stream to + // the given message object. This function parses the human-readable format + // written by Print(). Returns true on success. The message is cleared first, + // even if the function fails -- See Merge() to avoid this behavior. + // + // Example input: "user {\n id: 123 extra { gender: MALE language: 'en' }\n}" + // + // One use for this function is parsing handwritten strings in test code. + // Another use is to parse the output from google::protobuf::Message::DebugString() + // (or ShortDebugString()), because these functions output using + // google::protobuf::TextFormat::Print(). + // + // If you would like to read a protocol buffer serialized in the + // (non-human-readable) binary wire format, see + // google::protobuf::MessageLite::ParseFromString(). + static bool Parse(io::ZeroCopyInputStream* input, Message* output); + // Like Parse(), but reads directly from a string. + static bool ParseFromString(const std::string& input, Message* output); + + // Like Parse(), but the data is merged into the given message, as if + // using Message::MergeFrom(). + static bool Merge(io::ZeroCopyInputStream* input, Message* output); + // Like Merge(), but reads directly from a string. + static bool MergeFromString(const std::string& input, Message* output); + + // Parse the given text as a single field value and store it into the + // given field of the given message. If the field is a repeated field, + // the new value will be added to the end + static bool ParseFieldValueFromString(const std::string& input, + const FieldDescriptor* field, + Message* message); + + // A location in the parsed text. + struct ParseLocation { + int line; + int column; + + ParseLocation() : line(-1), column(-1) {} + ParseLocation(int line_param, int column_param) + : line(line_param), column(column_param) {} + }; + + // Data structure which is populated with the locations of each field + // value parsed from the text. + class PROTOBUF_EXPORT ParseInfoTree { + public: + ParseInfoTree() = default; + ParseInfoTree(const ParseInfoTree&) = delete; + ParseInfoTree& operator=(const ParseInfoTree&) = delete; + + // Returns the parse location for index-th value of the field in the parsed + // text. If none exists, returns a location with line = -1. Index should be + // -1 for not-repeated fields. + ParseLocation GetLocation(const FieldDescriptor* field, int index) const; + + // Returns the parse info tree for the given field, which must be a message + // type. The nested information tree is owned by the root tree and will be + // deleted when it is deleted. + ParseInfoTree* GetTreeForNested(const FieldDescriptor* field, + int index) const; + + private: + // Allow the text format parser to record information into the tree. + friend class TextFormat; + + // Records the starting location of a single value for a field. + void RecordLocation(const FieldDescriptor* field, ParseLocation location); + + // Create and records a nested tree for a nested message field. + ParseInfoTree* CreateNested(const FieldDescriptor* field); + + // Defines the map from the index-th field descriptor to its parse location. + typedef std::map > + LocationMap; + + // Defines the map from the index-th field descriptor to the nested parse + // info tree. + typedef std::map>> + NestedMap; + + LocationMap locations_; + NestedMap nested_; + }; + + // For more control over parsing, use this class. + class PROTOBUF_EXPORT Parser { + public: + Parser(); + ~Parser(); + + // Like TextFormat::Parse(). + bool Parse(io::ZeroCopyInputStream* input, Message* output); + // Like TextFormat::ParseFromString(). + bool ParseFromString(const std::string& input, Message* output); + // Like TextFormat::Merge(). + bool Merge(io::ZeroCopyInputStream* input, Message* output); + // Like TextFormat::MergeFromString(). + bool MergeFromString(const std::string& input, Message* output); + + // Set where to report parse errors. If NULL (the default), errors will + // be printed to stderr. + void RecordErrorsTo(io::ErrorCollector* error_collector) { + error_collector_ = error_collector; + } + + // Set how parser finds extensions. If NULL (the default), the + // parser will use the standard Reflection object associated with + // the message being parsed. + void SetFinder(const Finder* finder) { finder_ = finder; } + + // Sets where location information about the parse will be written. If NULL + // (the default), then no location will be written. + void WriteLocationsTo(ParseInfoTree* tree) { parse_info_tree_ = tree; } + + // Normally parsing fails if, after parsing, output->IsInitialized() + // returns false. Call AllowPartialMessage(true) to skip this check. + void AllowPartialMessage(bool allow) { allow_partial_ = allow; } + + // Allow field names to be matched case-insensitively. + // This is not advisable if there are fields that only differ in case, or + // if you want to enforce writing in the canonical form. + // This is 'false' by default. + void AllowCaseInsensitiveField(bool allow) { + allow_case_insensitive_field_ = allow; + } + + // Like TextFormat::ParseFieldValueFromString + bool ParseFieldValueFromString(const std::string& input, + const FieldDescriptor* field, + Message* output); + + // When an unknown extension is met, parsing will fail if this option is set + // to false (the default). If true, unknown extensions will be ignored and + // a warning message will be generated. + void AllowUnknownExtension(bool allow) { allow_unknown_extension_ = allow; } + + // When an unknown field is met, parsing will fail if this option is set + // to false(the default). If true, unknown fields will be ignored and + // a warning message will be generated. + // Please aware that set this option true may hide some errors (e.g. + // spelling error on field name). Avoid to use this option if possible. + void AllowUnknownField(bool allow) { allow_unknown_field_ = allow; } + + + void AllowFieldNumber(bool allow) { allow_field_number_ = allow; } + + // Sets maximum recursion depth which parser can use. This is effectively + // the maximum allowed nesting of proto messages. + void SetRecursionLimit(int limit) { recursion_limit_ = limit; } + + private: + // Forward declaration of an internal class used to parse text + // representations (see text_format.cc for implementation). + class ParserImpl; + + // Like TextFormat::Merge(). The provided implementation is used + // to do the parsing. + bool MergeUsingImpl(io::ZeroCopyInputStream* input, Message* output, + ParserImpl* parser_impl); + + io::ErrorCollector* error_collector_; + const Finder* finder_; + ParseInfoTree* parse_info_tree_; + bool allow_partial_; + bool allow_case_insensitive_field_; + bool allow_unknown_field_; + bool allow_unknown_extension_; + bool allow_unknown_enum_; + bool allow_field_number_; + bool allow_relaxed_whitespace_; + bool allow_singular_overwrites_; + int recursion_limit_; + }; + + + private: + // Hack: ParseInfoTree declares TextFormat as a friend which should extend + // the friendship to TextFormat::Parser::ParserImpl, but unfortunately some + // old compilers (e.g. GCC 3.4.6) don't implement this correctly. We provide + // helpers for ParserImpl to call methods of ParseInfoTree. + static inline void RecordLocation(ParseInfoTree* info_tree, + const FieldDescriptor* field, + ParseLocation location); + static inline ParseInfoTree* CreateNested(ParseInfoTree* info_tree, + const FieldDescriptor* field); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(TextFormat); +}; + +inline void TextFormat::RecordLocation(ParseInfoTree* info_tree, + const FieldDescriptor* field, + ParseLocation location) { + info_tree->RecordLocation(field, location); +} + +inline TextFormat::ParseInfoTree* TextFormat::CreateNested( + ParseInfoTree* info_tree, const FieldDescriptor* field) { + return info_tree->CreateNested(field); +} + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_TEXT_FORMAT_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/timestamp.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/timestamp.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..95f333e728d8c81d5315bb1ad4143b024ecb9c4c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/timestamp.pb.h @@ -0,0 +1,282 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/timestamp.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ftimestamp_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ftimestamp_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2ftimestamp_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2ftimestamp_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2ftimestamp_2eproto; +PROTOBUF_NAMESPACE_OPEN +class Timestamp; +class TimestampDefaultTypeInternal; +PROTOBUF_EXPORT extern TimestampDefaultTypeInternal _Timestamp_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Timestamp* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +// =================================================================== + +class PROTOBUF_EXPORT Timestamp PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Timestamp) */ { + public: + inline Timestamp() : Timestamp(nullptr) {} + virtual ~Timestamp(); + + Timestamp(const Timestamp& from); + Timestamp(Timestamp&& from) noexcept + : Timestamp() { + *this = ::std::move(from); + } + + inline Timestamp& operator=(const Timestamp& from) { + CopyFrom(from); + return *this; + } + inline Timestamp& operator=(Timestamp&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Timestamp& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Timestamp* internal_default_instance() { + return reinterpret_cast( + &_Timestamp_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Timestamp& a, Timestamp& b) { + a.Swap(&b); + } + inline void Swap(Timestamp* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Timestamp* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Timestamp* New() const final { + return CreateMaybeMessage(nullptr); + } + + Timestamp* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Timestamp& from); + void MergeFrom(const Timestamp& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Timestamp* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Timestamp"; + } + protected: + explicit Timestamp(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2ftimestamp_2eproto); + return ::descriptor_table_google_2fprotobuf_2ftimestamp_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kSecondsFieldNumber = 1, + kNanosFieldNumber = 2, + }; + // int64 seconds = 1; + void clear_seconds(); + ::PROTOBUF_NAMESPACE_ID::int64 seconds() const; + void set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_seconds() const; + void _internal_set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // int32 nanos = 2; + void clear_nanos(); + ::PROTOBUF_NAMESPACE_ID::int32 nanos() const; + void set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_nanos() const; + void _internal_set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Timestamp) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::int64 seconds_; + ::PROTOBUF_NAMESPACE_ID::int32 nanos_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2ftimestamp_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Timestamp + +// int64 seconds = 1; +inline void Timestamp::clear_seconds() { + seconds_ = PROTOBUF_LONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Timestamp::_internal_seconds() const { + return seconds_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Timestamp::seconds() const { + // @@protoc_insertion_point(field_get:google.protobuf.Timestamp.seconds) + return _internal_seconds(); +} +inline void Timestamp::_internal_set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value) { + + seconds_ = value; +} +inline void Timestamp::set_seconds(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_seconds(value); + // @@protoc_insertion_point(field_set:google.protobuf.Timestamp.seconds) +} + +// int32 nanos = 2; +inline void Timestamp::clear_nanos() { + nanos_ = 0; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Timestamp::_internal_nanos() const { + return nanos_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Timestamp::nanos() const { + // @@protoc_insertion_point(field_get:google.protobuf.Timestamp.nanos) + return _internal_nanos(); +} +inline void Timestamp::_internal_set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value) { + + nanos_ = value; +} +inline void Timestamp::set_nanos(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_nanos(value); + // @@protoc_insertion_point(field_set:google.protobuf.Timestamp.nanos) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ftimestamp_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/type.pb.h b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/type.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..e54ceddf44abf0249bac5814020fa0ec0b05de88 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/google/protobuf/type.pb.h @@ -0,0 +1,2612 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/type.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ftype_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ftype_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2ftype_2eproto PROTOBUF_EXPORT +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2ftype_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[5] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2ftype_2eproto; +PROTOBUF_NAMESPACE_OPEN +class Enum; +class EnumDefaultTypeInternal; +PROTOBUF_EXPORT extern EnumDefaultTypeInternal _Enum_default_instance_; +class EnumValue; +class EnumValueDefaultTypeInternal; +PROTOBUF_EXPORT extern EnumValueDefaultTypeInternal _EnumValue_default_instance_; +class Field; +class FieldDefaultTypeInternal; +PROTOBUF_EXPORT extern FieldDefaultTypeInternal _Field_default_instance_; +class Option; +class OptionDefaultTypeInternal; +PROTOBUF_EXPORT extern OptionDefaultTypeInternal _Option_default_instance_; +class Type; +class TypeDefaultTypeInternal; +PROTOBUF_EXPORT extern TypeDefaultTypeInternal _Type_default_instance_; +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Enum* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::EnumValue* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Field* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Option* Arena::CreateMaybeMessage(Arena*); +template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::Type* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN + +enum Field_Kind : int { + Field_Kind_TYPE_UNKNOWN = 0, + Field_Kind_TYPE_DOUBLE = 1, + Field_Kind_TYPE_FLOAT = 2, + Field_Kind_TYPE_INT64 = 3, + Field_Kind_TYPE_UINT64 = 4, + Field_Kind_TYPE_INT32 = 5, + Field_Kind_TYPE_FIXED64 = 6, + Field_Kind_TYPE_FIXED32 = 7, + Field_Kind_TYPE_BOOL = 8, + Field_Kind_TYPE_STRING = 9, + Field_Kind_TYPE_GROUP = 10, + Field_Kind_TYPE_MESSAGE = 11, + Field_Kind_TYPE_BYTES = 12, + Field_Kind_TYPE_UINT32 = 13, + Field_Kind_TYPE_ENUM = 14, + Field_Kind_TYPE_SFIXED32 = 15, + Field_Kind_TYPE_SFIXED64 = 16, + Field_Kind_TYPE_SINT32 = 17, + Field_Kind_TYPE_SINT64 = 18, + Field_Kind_Field_Kind_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + Field_Kind_Field_Kind_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +PROTOBUF_EXPORT bool Field_Kind_IsValid(int value); +constexpr Field_Kind Field_Kind_Kind_MIN = Field_Kind_TYPE_UNKNOWN; +constexpr Field_Kind Field_Kind_Kind_MAX = Field_Kind_TYPE_SINT64; +constexpr int Field_Kind_Kind_ARRAYSIZE = Field_Kind_Kind_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Field_Kind_descriptor(); +template +inline const std::string& Field_Kind_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Field_Kind_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + Field_Kind_descriptor(), enum_t_value); +} +inline bool Field_Kind_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, Field_Kind* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + Field_Kind_descriptor(), name, value); +} +enum Field_Cardinality : int { + Field_Cardinality_CARDINALITY_UNKNOWN = 0, + Field_Cardinality_CARDINALITY_OPTIONAL = 1, + Field_Cardinality_CARDINALITY_REQUIRED = 2, + Field_Cardinality_CARDINALITY_REPEATED = 3, + Field_Cardinality_Field_Cardinality_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + Field_Cardinality_Field_Cardinality_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +PROTOBUF_EXPORT bool Field_Cardinality_IsValid(int value); +constexpr Field_Cardinality Field_Cardinality_Cardinality_MIN = Field_Cardinality_CARDINALITY_UNKNOWN; +constexpr Field_Cardinality Field_Cardinality_Cardinality_MAX = Field_Cardinality_CARDINALITY_REPEATED; +constexpr int Field_Cardinality_Cardinality_ARRAYSIZE = Field_Cardinality_Cardinality_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Field_Cardinality_descriptor(); +template +inline const std::string& Field_Cardinality_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Field_Cardinality_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + Field_Cardinality_descriptor(), enum_t_value); +} +inline bool Field_Cardinality_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, Field_Cardinality* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + Field_Cardinality_descriptor(), name, value); +} +enum Syntax : int { + SYNTAX_PROTO2 = 0, + SYNTAX_PROTO3 = 1, + Syntax_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + Syntax_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +PROTOBUF_EXPORT bool Syntax_IsValid(int value); +constexpr Syntax Syntax_MIN = SYNTAX_PROTO2; +constexpr Syntax Syntax_MAX = SYNTAX_PROTO3; +constexpr int Syntax_ARRAYSIZE = Syntax_MAX + 1; + +PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Syntax_descriptor(); +template +inline const std::string& Syntax_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Syntax_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + Syntax_descriptor(), enum_t_value); +} +inline bool Syntax_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, Syntax* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + Syntax_descriptor(), name, value); +} +// =================================================================== + +class PROTOBUF_EXPORT Type PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Type) */ { + public: + inline Type() : Type(nullptr) {} + virtual ~Type(); + + Type(const Type& from); + Type(Type&& from) noexcept + : Type() { + *this = ::std::move(from); + } + + inline Type& operator=(const Type& from) { + CopyFrom(from); + return *this; + } + inline Type& operator=(Type&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Type& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Type* internal_default_instance() { + return reinterpret_cast( + &_Type_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Type& a, Type& b) { + a.Swap(&b); + } + inline void Swap(Type* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Type* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Type* New() const final { + return CreateMaybeMessage(nullptr); + } + + Type* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Type& from); + void MergeFrom(const Type& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Type* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Type"; + } + protected: + explicit Type(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2ftype_2eproto); + return ::descriptor_table_google_2fprotobuf_2ftype_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kFieldsFieldNumber = 2, + kOneofsFieldNumber = 3, + kOptionsFieldNumber = 4, + kNameFieldNumber = 1, + kSourceContextFieldNumber = 5, + kSyntaxFieldNumber = 6, + }; + // repeated .google.protobuf.Field fields = 2; + int fields_size() const; + private: + int _internal_fields_size() const; + public: + void clear_fields(); + PROTOBUF_NAMESPACE_ID::Field* mutable_fields(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Field >* + mutable_fields(); + private: + const PROTOBUF_NAMESPACE_ID::Field& _internal_fields(int index) const; + PROTOBUF_NAMESPACE_ID::Field* _internal_add_fields(); + public: + const PROTOBUF_NAMESPACE_ID::Field& fields(int index) const; + PROTOBUF_NAMESPACE_ID::Field* add_fields(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Field >& + fields() const; + + // repeated string oneofs = 3; + int oneofs_size() const; + private: + int _internal_oneofs_size() const; + public: + void clear_oneofs(); + const std::string& oneofs(int index) const; + std::string* mutable_oneofs(int index); + void set_oneofs(int index, const std::string& value); + void set_oneofs(int index, std::string&& value); + void set_oneofs(int index, const char* value); + void set_oneofs(int index, const char* value, size_t size); + std::string* add_oneofs(); + void add_oneofs(const std::string& value); + void add_oneofs(std::string&& value); + void add_oneofs(const char* value); + void add_oneofs(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& oneofs() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_oneofs(); + private: + const std::string& _internal_oneofs(int index) const; + std::string* _internal_add_oneofs(); + public: + + // repeated .google.protobuf.Option options = 4; + int options_size() const; + private: + int _internal_options_size() const; + public: + void clear_options(); + PROTOBUF_NAMESPACE_ID::Option* mutable_options(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* + mutable_options(); + private: + const PROTOBUF_NAMESPACE_ID::Option& _internal_options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* _internal_add_options(); + public: + const PROTOBUF_NAMESPACE_ID::Option& options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* add_options(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& + options() const; + + // string name = 1; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // .google.protobuf.SourceContext source_context = 5; + bool has_source_context() const; + private: + bool _internal_has_source_context() const; + public: + void clear_source_context(); + const PROTOBUF_NAMESPACE_ID::SourceContext& source_context() const; + PROTOBUF_NAMESPACE_ID::SourceContext* release_source_context(); + PROTOBUF_NAMESPACE_ID::SourceContext* mutable_source_context(); + void set_allocated_source_context(PROTOBUF_NAMESPACE_ID::SourceContext* source_context); + private: + const PROTOBUF_NAMESPACE_ID::SourceContext& _internal_source_context() const; + PROTOBUF_NAMESPACE_ID::SourceContext* _internal_mutable_source_context(); + public: + void unsafe_arena_set_allocated_source_context( + PROTOBUF_NAMESPACE_ID::SourceContext* source_context); + PROTOBUF_NAMESPACE_ID::SourceContext* unsafe_arena_release_source_context(); + + // .google.protobuf.Syntax syntax = 6; + void clear_syntax(); + PROTOBUF_NAMESPACE_ID::Syntax syntax() const; + void set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + private: + PROTOBUF_NAMESPACE_ID::Syntax _internal_syntax() const; + void _internal_set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Type) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Field > fields_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField oneofs_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option > options_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + PROTOBUF_NAMESPACE_ID::SourceContext* source_context_; + int syntax_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2ftype_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT Field PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Field) */ { + public: + inline Field() : Field(nullptr) {} + virtual ~Field(); + + Field(const Field& from); + Field(Field&& from) noexcept + : Field() { + *this = ::std::move(from); + } + + inline Field& operator=(const Field& from) { + CopyFrom(from); + return *this; + } + inline Field& operator=(Field&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Field& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Field* internal_default_instance() { + return reinterpret_cast( + &_Field_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(Field& a, Field& b) { + a.Swap(&b); + } + inline void Swap(Field* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Field* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Field* New() const final { + return CreateMaybeMessage(nullptr); + } + + Field* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Field& from); + void MergeFrom(const Field& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Field* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Field"; + } + protected: + explicit Field(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2ftype_2eproto); + return ::descriptor_table_google_2fprotobuf_2ftype_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef Field_Kind Kind; + static constexpr Kind TYPE_UNKNOWN = + Field_Kind_TYPE_UNKNOWN; + static constexpr Kind TYPE_DOUBLE = + Field_Kind_TYPE_DOUBLE; + static constexpr Kind TYPE_FLOAT = + Field_Kind_TYPE_FLOAT; + static constexpr Kind TYPE_INT64 = + Field_Kind_TYPE_INT64; + static constexpr Kind TYPE_UINT64 = + Field_Kind_TYPE_UINT64; + static constexpr Kind TYPE_INT32 = + Field_Kind_TYPE_INT32; + static constexpr Kind TYPE_FIXED64 = + Field_Kind_TYPE_FIXED64; + static constexpr Kind TYPE_FIXED32 = + Field_Kind_TYPE_FIXED32; + static constexpr Kind TYPE_BOOL = + Field_Kind_TYPE_BOOL; + static constexpr Kind TYPE_STRING = + Field_Kind_TYPE_STRING; + static constexpr Kind TYPE_GROUP = + Field_Kind_TYPE_GROUP; + static constexpr Kind TYPE_MESSAGE = + Field_Kind_TYPE_MESSAGE; + static constexpr Kind TYPE_BYTES = + Field_Kind_TYPE_BYTES; + static constexpr Kind TYPE_UINT32 = + Field_Kind_TYPE_UINT32; + static constexpr Kind TYPE_ENUM = + Field_Kind_TYPE_ENUM; + static constexpr Kind TYPE_SFIXED32 = + Field_Kind_TYPE_SFIXED32; + static constexpr Kind TYPE_SFIXED64 = + Field_Kind_TYPE_SFIXED64; + static constexpr Kind TYPE_SINT32 = + Field_Kind_TYPE_SINT32; + static constexpr Kind TYPE_SINT64 = + Field_Kind_TYPE_SINT64; + static inline bool Kind_IsValid(int value) { + return Field_Kind_IsValid(value); + } + static constexpr Kind Kind_MIN = + Field_Kind_Kind_MIN; + static constexpr Kind Kind_MAX = + Field_Kind_Kind_MAX; + static constexpr int Kind_ARRAYSIZE = + Field_Kind_Kind_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + Kind_descriptor() { + return Field_Kind_descriptor(); + } + template + static inline const std::string& Kind_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Kind_Name."); + return Field_Kind_Name(enum_t_value); + } + static inline bool Kind_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + Kind* value) { + return Field_Kind_Parse(name, value); + } + + typedef Field_Cardinality Cardinality; + static constexpr Cardinality CARDINALITY_UNKNOWN = + Field_Cardinality_CARDINALITY_UNKNOWN; + static constexpr Cardinality CARDINALITY_OPTIONAL = + Field_Cardinality_CARDINALITY_OPTIONAL; + static constexpr Cardinality CARDINALITY_REQUIRED = + Field_Cardinality_CARDINALITY_REQUIRED; + static constexpr Cardinality CARDINALITY_REPEATED = + Field_Cardinality_CARDINALITY_REPEATED; + static inline bool Cardinality_IsValid(int value) { + return Field_Cardinality_IsValid(value); + } + static constexpr Cardinality Cardinality_MIN = + Field_Cardinality_Cardinality_MIN; + static constexpr Cardinality Cardinality_MAX = + Field_Cardinality_Cardinality_MAX; + static constexpr int Cardinality_ARRAYSIZE = + Field_Cardinality_Cardinality_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + Cardinality_descriptor() { + return Field_Cardinality_descriptor(); + } + template + static inline const std::string& Cardinality_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Cardinality_Name."); + return Field_Cardinality_Name(enum_t_value); + } + static inline bool Cardinality_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + Cardinality* value) { + return Field_Cardinality_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kOptionsFieldNumber = 9, + kNameFieldNumber = 4, + kTypeUrlFieldNumber = 6, + kJsonNameFieldNumber = 10, + kDefaultValueFieldNumber = 11, + kKindFieldNumber = 1, + kCardinalityFieldNumber = 2, + kNumberFieldNumber = 3, + kOneofIndexFieldNumber = 7, + kPackedFieldNumber = 8, + }; + // repeated .google.protobuf.Option options = 9; + int options_size() const; + private: + int _internal_options_size() const; + public: + void clear_options(); + PROTOBUF_NAMESPACE_ID::Option* mutable_options(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* + mutable_options(); + private: + const PROTOBUF_NAMESPACE_ID::Option& _internal_options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* _internal_add_options(); + public: + const PROTOBUF_NAMESPACE_ID::Option& options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* add_options(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& + options() const; + + // string name = 4; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // string type_url = 6; + void clear_type_url(); + const std::string& type_url() const; + void set_type_url(const std::string& value); + void set_type_url(std::string&& value); + void set_type_url(const char* value); + void set_type_url(const char* value, size_t size); + std::string* mutable_type_url(); + std::string* release_type_url(); + void set_allocated_type_url(std::string* type_url); + private: + const std::string& _internal_type_url() const; + void _internal_set_type_url(const std::string& value); + std::string* _internal_mutable_type_url(); + public: + + // string json_name = 10; + void clear_json_name(); + const std::string& json_name() const; + void set_json_name(const std::string& value); + void set_json_name(std::string&& value); + void set_json_name(const char* value); + void set_json_name(const char* value, size_t size); + std::string* mutable_json_name(); + std::string* release_json_name(); + void set_allocated_json_name(std::string* json_name); + private: + const std::string& _internal_json_name() const; + void _internal_set_json_name(const std::string& value); + std::string* _internal_mutable_json_name(); + public: + + // string default_value = 11; + void clear_default_value(); + const std::string& default_value() const; + void set_default_value(const std::string& value); + void set_default_value(std::string&& value); + void set_default_value(const char* value); + void set_default_value(const char* value, size_t size); + std::string* mutable_default_value(); + std::string* release_default_value(); + void set_allocated_default_value(std::string* default_value); + private: + const std::string& _internal_default_value() const; + void _internal_set_default_value(const std::string& value); + std::string* _internal_mutable_default_value(); + public: + + // .google.protobuf.Field.Kind kind = 1; + void clear_kind(); + PROTOBUF_NAMESPACE_ID::Field_Kind kind() const; + void set_kind(PROTOBUF_NAMESPACE_ID::Field_Kind value); + private: + PROTOBUF_NAMESPACE_ID::Field_Kind _internal_kind() const; + void _internal_set_kind(PROTOBUF_NAMESPACE_ID::Field_Kind value); + public: + + // .google.protobuf.Field.Cardinality cardinality = 2; + void clear_cardinality(); + PROTOBUF_NAMESPACE_ID::Field_Cardinality cardinality() const; + void set_cardinality(PROTOBUF_NAMESPACE_ID::Field_Cardinality value); + private: + PROTOBUF_NAMESPACE_ID::Field_Cardinality _internal_cardinality() const; + void _internal_set_cardinality(PROTOBUF_NAMESPACE_ID::Field_Cardinality value); + public: + + // int32 number = 3; + void clear_number(); + ::PROTOBUF_NAMESPACE_ID::int32 number() const; + void set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_number() const; + void _internal_set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // int32 oneof_index = 7; + void clear_oneof_index(); + ::PROTOBUF_NAMESPACE_ID::int32 oneof_index() const; + void set_oneof_index(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_oneof_index() const; + void _internal_set_oneof_index(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // bool packed = 8; + void clear_packed(); + bool packed() const; + void set_packed(bool value); + private: + bool _internal_packed() const; + void _internal_set_packed(bool value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Field) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option > options_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr type_url_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr json_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr default_value_; + int kind_; + int cardinality_; + ::PROTOBUF_NAMESPACE_ID::int32 number_; + ::PROTOBUF_NAMESPACE_ID::int32 oneof_index_; + bool packed_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2ftype_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT Enum PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Enum) */ { + public: + inline Enum() : Enum(nullptr) {} + virtual ~Enum(); + + Enum(const Enum& from); + Enum(Enum&& from) noexcept + : Enum() { + *this = ::std::move(from); + } + + inline Enum& operator=(const Enum& from) { + CopyFrom(from); + return *this; + } + inline Enum& operator=(Enum&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Enum& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Enum* internal_default_instance() { + return reinterpret_cast( + &_Enum_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(Enum& a, Enum& b) { + a.Swap(&b); + } + inline void Swap(Enum* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Enum* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Enum* New() const final { + return CreateMaybeMessage(nullptr); + } + + Enum* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Enum& from); + void MergeFrom(const Enum& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Enum* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.Enum"; + } + protected: + explicit Enum(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2ftype_2eproto); + return ::descriptor_table_google_2fprotobuf_2ftype_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kEnumvalueFieldNumber = 2, + kOptionsFieldNumber = 3, + kNameFieldNumber = 1, + kSourceContextFieldNumber = 4, + kSyntaxFieldNumber = 5, + }; + // repeated .google.protobuf.EnumValue enumvalue = 2; + int enumvalue_size() const; + private: + int _internal_enumvalue_size() const; + public: + void clear_enumvalue(); + PROTOBUF_NAMESPACE_ID::EnumValue* mutable_enumvalue(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValue >* + mutable_enumvalue(); + private: + const PROTOBUF_NAMESPACE_ID::EnumValue& _internal_enumvalue(int index) const; + PROTOBUF_NAMESPACE_ID::EnumValue* _internal_add_enumvalue(); + public: + const PROTOBUF_NAMESPACE_ID::EnumValue& enumvalue(int index) const; + PROTOBUF_NAMESPACE_ID::EnumValue* add_enumvalue(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValue >& + enumvalue() const; + + // repeated .google.protobuf.Option options = 3; + int options_size() const; + private: + int _internal_options_size() const; + public: + void clear_options(); + PROTOBUF_NAMESPACE_ID::Option* mutable_options(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* + mutable_options(); + private: + const PROTOBUF_NAMESPACE_ID::Option& _internal_options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* _internal_add_options(); + public: + const PROTOBUF_NAMESPACE_ID::Option& options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* add_options(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& + options() const; + + // string name = 1; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // .google.protobuf.SourceContext source_context = 4; + bool has_source_context() const; + private: + bool _internal_has_source_context() const; + public: + void clear_source_context(); + const PROTOBUF_NAMESPACE_ID::SourceContext& source_context() const; + PROTOBUF_NAMESPACE_ID::SourceContext* release_source_context(); + PROTOBUF_NAMESPACE_ID::SourceContext* mutable_source_context(); + void set_allocated_source_context(PROTOBUF_NAMESPACE_ID::SourceContext* source_context); + private: + const PROTOBUF_NAMESPACE_ID::SourceContext& _internal_source_context() const; + PROTOBUF_NAMESPACE_ID::SourceContext* _internal_mutable_source_context(); + public: + void unsafe_arena_set_allocated_source_context( + PROTOBUF_NAMESPACE_ID::SourceContext* source_context); + PROTOBUF_NAMESPACE_ID::SourceContext* unsafe_arena_release_source_context(); + + // .google.protobuf.Syntax syntax = 5; + void clear_syntax(); + PROTOBUF_NAMESPACE_ID::Syntax syntax() const; + void set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + private: + PROTOBUF_NAMESPACE_ID::Syntax _internal_syntax() const; + void _internal_set_syntax(PROTOBUF_NAMESPACE_ID::Syntax value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.Enum) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::EnumValue > enumvalue_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option > options_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + PROTOBUF_NAMESPACE_ID::SourceContext* source_context_; + int syntax_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2ftype_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT EnumValue PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.EnumValue) */ { + public: + inline EnumValue() : EnumValue(nullptr) {} + virtual ~EnumValue(); + + EnumValue(const EnumValue& from); + EnumValue(EnumValue&& from) noexcept + : EnumValue() { + *this = ::std::move(from); + } + + inline EnumValue& operator=(const EnumValue& from) { + CopyFrom(from); + return *this; + } + inline EnumValue& operator=(EnumValue&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const EnumValue& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const EnumValue* internal_default_instance() { + return reinterpret_cast( + &_EnumValue_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(EnumValue& a, EnumValue& b) { + a.Swap(&b); + } + inline void Swap(EnumValue* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(EnumValue* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline EnumValue* New() const final { + return CreateMaybeMessage(nullptr); + } + + EnumValue* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const EnumValue& from); + void MergeFrom(const EnumValue& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(EnumValue* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.EnumValue"; + } + protected: + explicit EnumValue(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2ftype_2eproto); + return ::descriptor_table_google_2fprotobuf_2ftype_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kOptionsFieldNumber = 3, + kNameFieldNumber = 1, + kNumberFieldNumber = 2, + }; + // repeated .google.protobuf.Option options = 3; + int options_size() const; + private: + int _internal_options_size() const; + public: + void clear_options(); + PROTOBUF_NAMESPACE_ID::Option* mutable_options(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >* + mutable_options(); + private: + const PROTOBUF_NAMESPACE_ID::Option& _internal_options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* _internal_add_options(); + public: + const PROTOBUF_NAMESPACE_ID::Option& options(int index) const; + PROTOBUF_NAMESPACE_ID::Option* add_options(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option >& + options() const; + + // string name = 1; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // int32 number = 2; + void clear_number(); + ::PROTOBUF_NAMESPACE_ID::int32 number() const; + void set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_number() const; + void _internal_set_number(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.EnumValue) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::Option > options_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::int32 number_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_google_2fprotobuf_2ftype_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOBUF_EXPORT Option PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.Option) */ { + public: + inline Option() : Option(nullptr) {} + virtual ~Option(); + + Option(const Option& from); + Option(Option&& from) noexcept + : Option() { + *this = ::std::move(from); + } + + inline Option& operator=(const Option& from) { + CopyFrom(from); + return *this; + } + inline Option& operator=(Option&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Option& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Option* internal_default_instance() { + return reinterpret_cast( + &_Option_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(Option& a, Option& b) { + a.Swap(&b); + } + inline void Swap(Option* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Option* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Option* New() const final { + return CreateMaybeMessage